In [None]:
%%bash
echo "Installing requirements .."
pip install pandas==0.22.0 quandl pandas_datareader alpha_vantage matplotlib plotly sklearn scipy fix_yahoo_finance statsmodels beautifulsoup4 > /dev/null 2>&1
# NOTE: pandas is pinned to 0.22 for now because pandas_datareader doesn't support 0.23 yet
echo "Done"

In [None]:
import os
import datetime
import numbers
import subprocess
import uuid
import string
import json 
import requests
from io import StringIO
import re
import math

import pandas as pd
import numpy as np
import sklearn as sk
from sklearn import linear_model

import quandl
quandl.ApiConfig.api_key = "9nrUn7Sm1SdoeLdQGQB-"

import pandas_datareader
from pandas_datareader import data as pdr
import fix_yahoo_finance as yf
yf.pdr_override() # <== that's all it takes :-)
import alpha_vantage
from alpha_vantage.timeseries import TimeSeries
from alpha_vantage.cryptocurrencies import CryptoCurrencies

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (20.0, 10.0) # Make plots bigger

import plotly.offline as py
import plotly.graph_objs as go
import plotly.graph_objs.layout as gol
py.init_notebook_mode()

from pathlib import Path
from bs4 import BeautifulSoup


In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Show the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Notebook-wide pandas display settings: floats with thousands separators and
# two decimals, and generous row/column limits so larger frames print in full.
pd.set_option('display.max_rows', 5000)
pd.set_option('display.max_columns', 500)
pd.set_option('display.float_format', '{:,.2f}'.format)


In [None]:
def pd_from_dict(d):
    """Build a DataFrame with one column per key of `d`, rows sorted by index.

    Each value is laid out as a row first and then transposed, so entries that
    a value does not have become NaN.
    """
    frame = pd.DataFrame.from_dict(d, orient='index')
    return frame.transpose().sort_index()

In [None]:
# (hack) Global configs
conf_cache_disk = True     # allow DataSource.get to read the on-disk CSV cache
conf_cache_memory = True   # allow DataSource.get to read the in-memory symbol cache
conf_cache_fails = True    # honour the 24h "recently failed" marker files

class GetConf:
    """Plain bag of per-request options handed to DataSource.get/fetch/process."""

    def __init__(self, splitAdj, divAdj, cache, mode, source, secondary):
        options = (
            ("splitAdj", splitAdj),
            ("divAdj", divAdj),
            ("cache", cache),
            ("mode", mode),
            ("source", source),
            ("secondary", secondary),
        )
        for attr, value in options:
            setattr(self, attr, value)

In [None]:
# Module-level caches. Reusing any existing value keeps their contents when
# this cell is re-run in the same kernel.
fetchCache = globals().get("fetchCache", {})
symbols_mem_cache = globals().get("symbols_mem_cache", {})
    

In [None]:
class Symbol:
    """A parsed symbol reference of the form "NAME[@SOURCE][!CURRENCY]".

    Attributes:
        fullname: the original string, unchanged.
        currency: text after "!" when there is exactly one "!", else "".
        name: text before the first "@" (after removing a single "!CURRENCY").
        source: text after "@" when there is exactly one "@", else "".
    """

    def __init__(self, fullname):
        self.fullname = fullname
        if fullname.count("!") == 1:
            base, self.currency = fullname.split("!")
        else:
            base, self.currency = fullname, ""
        self.name = base.split("@")[0]
        self.source = base.split("@")[1] if base.count("@") == 1 else ""

    def __str__(self):
        return self.fullname

In [None]:
import scipy.optimize
from datetime import datetime as dt
def xnpv(rate, values, dates):
    '''Equivalent of Excel's XNPV function.

    Each cash flow is discounted by (1 + rate) ** (days since dates[0] / 365).
    Returns +inf for rate <= -1, where the discount factor is undefined; xirr
    uses -1 as the lower edge of its search bracket.

    >>> from datetime import date
    >>> dates = [date(2010, 12, 29), date(2012, 1, 25), date(2012, 3, 8)]
    >>> values = [-10000, 20, 10100]
    >>> xnpv(0.1, values, dates)
    -966.4345...
    '''
    if rate <= -1.0:
        return float('inf')
    d0 = dates[0]    # or min(dates)
    return sum(vi / (1.0 + rate) ** ((di - d0).days / 365.0)
               for vi, di in zip(values, dates))


def xirr(values, dates):
    '''Equivalent of Excel's XIRR function.

    Tries scipy's brentq on the bracket [-1, 1e10] first, then falls back to
    newton (secant) starting from 0.

    >>> from datetime import date
    >>> dates = [date(2010, 12, 29), date(2012, 1, 25), date(2012, 3, 8)]
    >>> values = [-10000, 20, 10100]
    >>> xirr(values, dates)
    0.0100612...

    Raises:
        RuntimeError: if the newton fallback also fails to converge.
    '''
    def npv(r):
        return xnpv(r, values, dates)

    # we prefer to try brentq first as newton keeps outputting tolerance warnings
    try:
        return scipy.optimize.brentq(npv, -1.0, 1e10)
    except (RuntimeError, ValueError):
        # brentq raises ValueError when npv has the same sign at both bracket ends
        # (e.g. the first cash flow is positive), and RuntimeError when it fails
        # to converge; either way, try newton instead.
        return scipy.optimize.newton(npv, 0.0, tol=0.0002)

#xirr([-100, 100, 200], [dt(2000, 1, 1), dt(2001, 1, 1), dt(2002, 1, 1)])

In [None]:
def curr_price(symbol):
    """Return the most recent value of `symbol`'s series, or 0 for ignored assets.

    `ignoredAssets` is an optional notebook-level global that may be defined
    elsewhere. When it is missing or empty, no asset is ignored; this mirrors
    the check `get` performs.
    """
    ignored = globals().get("ignoredAssets") or ()
    if symbol in ignored:
        return 0
    return get(symbol)[-1]

#def getForex(fromCur, toCur):
#    if fromCur == toCur: return 1
#    if toCur == "USD":
#        return get(fromCur + "=X", "Y")
#    if fromCur == "USD":
#        return get(toCur + "=X", "Y").map(lambda x: 1.0/x)

def getForex(fromCur, toCur):
    """Daily fromCur->toCur exchange-rate series from investing.com ("@IC").

    The series is reindexed to every calendar day between its first and last
    date, and gaps are forward-filled. Returns the scalar 1 when both
    currencies are the same.
    """
    if fromCur == toCur:
        return 1
    rates = get(fromCur + "/" + toCur + "@IC").s
    calendar = pd.date_range(start=rates.index[0], end=rates.index[-1])
    return rates.reindex(calendar).fillna(method="ffill")

def convertSeries(s, fromCur, toCur):
    """Convert series `s` from fromCur to toCur using getForex's daily rate.

    Dates without both a value and a rate are dropped. `s` is returned
    untouched when the currencies match.
    """
    if fromCur == toCur:
        return s
    return (s * getForex(fromCur, toCur)).dropna()
    
def convertToday(value, fromCur, toCur):
    """Convert scalar `value` from fromCur to toCur at the latest available rate."""
    return value if fromCur == toCur else value * getForex(fromCur, toCur)[-1]


In [None]:
def getName(s):
    """Return `s` itself for strings, otherwise its `.name` attribute."""
    return s if isinstance(s, str) else s.name

def toSymbol(sym):
    """Coerce `sym` to a Symbol: strings are parsed and Symbols pass through.

    Any other type fails an assertion that names the offending type and value.
    """
    if isinstance(sym, str):
        return Symbol(sym)
    if isinstance(sym, Symbol):
        return sym
    assert False, "invalid type for Symbol: " + str(type(sym)) + ", " + str(sym)

class DataSource:
    """Base class for one market-data provider, identified by a short code.

    Subclasses override `fetch` (download the raw frame for a Symbol) and
    `process` (reduce that frame to the series requested by `conf.mode`).
    `get` adds memory/disk caching and a 24-hour back-off after failures.
    """

    def __init__(self, source):
        # provider code, e.g. "Y"; also used as the cache sub-directory name
        self.source = source

    def fetch(self, symbol, conf):
        """Return the raw DataFrame for `symbol`, or None if it cannot be fetched."""
        pass

    def process(self, symbol, df, conf):
        """Return the data for `conf.mode` extracted from the raw frame `df`."""
        pass

    def get(self, symbol, conf):
        """Return data for `symbol` (a Symbol), using the caches when allowed.

        Lookup order:
        1. The memory cache.
        2. The disk cache. Each cache is used only when `conf.cache` and the
           matching global switch are on.
        3. A fresh `fetch`.

        A failed fetch touches a "<cache file>._FAIL_" marker. While
        `conf_cache_fails` is set, further cache-miss fetches of that symbol
        raise for 24 hours. With `conf.mode == "raw"` the raw frame is
        returned, otherwise `process(...)`; either way it is sorted by index.

        Raises:
            Exception: when a recent failure marker exists, or when fetching
                fails or returns no rows (the underlying error is re-raised).
        """
        df = None
        fetched = False

        # get from mem cache
        if conf.cache and conf_cache_memory:
            df = symbols_mem_cache.get(symbol.fullname)

        # get from disk cache
        if df is None and conf.cache and conf_cache_disk:
            df = cache_get(symbol, self.source)

        # attempt to fetch the symbol
        if df is None:
            failpath = cache_file(symbol, self.source) + "._FAIL_"
            if conf_cache_fails and os.path.isfile(failpath):
                mtime = datetime.datetime.fromtimestamp(os.path.getmtime(failpath))
                age = datetime.datetime.now() - mtime
                if age.total_seconds() <= 24 * 3600:
                    raise Exception("Fetching has previously failed for {0}, will try again later".format(symbol))

            print("Fetching %s from %s .. " % (symbol, self.source), end="")
            try:
                df = self.fetch(symbol, conf)
                if df is None:
                    raise Exception("Failed to fetch symbol: " + str(symbol) + " from " + self.source)
                if len(df) == 0:
                    raise Exception("Symbol fetched but is empty: " + str(symbol) + " from " + self.source)
            except Exception:
                print("FAILED")
                # save a note that we failed, so we don't retry for the next 24h
                Path(failpath).touch()
                # bare raise keeps the original exception, message and traceback
                raise
            print("DONE")
            fetched = True

        # Write to the disk cache only after a fresh download; data read from
        # either cache has already been written there by an earlier fetch.
        if fetched:
            cache_set(symbol, self.source, df)
        # write to mem cache
        symbols_mem_cache[symbol.fullname] = df

        if conf.mode == "raw":
            res = df
        else:
            res = self.process(symbol, df, conf)
        return res.sort_index()

# ISO currency code -> FRED exchange-rate series code (requested as "FRED/<code>" via quandl).
# One tab-separated pair per line. ForexDataSource.__init__ drops the empty first and
# last lines with its [1:-1] slice. It also decides whether to invert a rate based on
# whether the code ends in "US".
fred_forex_codes = """
AUD	DEXUSAL
BRL	DEXBZUS
GBP	DEXUSUK
CAD	DEXCAUS
CNY	DEXCHUS
DKK	DEXDNUS
EUR	DEXUSEU
HKD	DEXHKUS
INR	DEXINUS
JPY	DEXJPUS
MYR	DEXMAUS
MXN	DEXMXUS
TWD	DEXTAUS
NOK	DEXNOUS
SGD	DEXSIUS
ZAR	DEXSFUS
KRW	DEXKOUS
LKR	DEXSLUS
SEK	DEXSDUS
CHF	DEXSZUS
VEF	DEXVZUS
"""

# ISO currency code -> Bank of England series code (requested as "BOE/<code>" via quandl).
# Same tab-separated format. "NIS" (Israeli shekel) is also aliased as "ILS" by
# ForexDataSource.__init__.
boe_forex_codes = """
AUD	XUDLADD
CAD	XUDLCDD
CNY	XUDLBK73
CZK	XUDLBK27
DKK	XUDLDKD
HKD	XUDLHDD
HUF	XUDLBK35
INR	XUDLBK64
NIS	XUDLBK65
JPY	XUDLJYD
LTL	XUDLBK38
MYR	XUDLBK66
NZD	XUDLNDD
NOK	XUDLNKD
PLN	XUDLBK49
GBP	XUDLGBD
RUB	XUDLBK69
SAR	XUDLSRD
SGD	XUDLSGD
ZAR	XUDLZRD
KRW	XUDLBK74
SEK	XUDLSKD
CHF	XUDLSFD
TWD	XUDLTWD
THB	XUDLBK72
TRY	XUDLBK75
"""

# https://blog.quandl.com/api-for-currency-data
class ForexDataSource(DataSource):
    """USD exchange rates from Quandl's FRED and Bank of England (BOE) datasets.

    Symbol names are six letters, "<FROM><TO>", and one side must be USD. The
    inversion logic below aims for "units of <TO> per one <FROM>". "GBC"
    (pence sterling) is served as GBP scaled by 100.
    """

    def __init__(self, source):
        # "CODE\tSERIES" lines; [1:-1] drops the empty first/last lines of the triple-quoted tables
        self.fred_code_map = dict([s.split("\t") for s in fred_forex_codes.split("\n")[1:-1]])
        self.boe_code_map = dict([s.split("\t") for s in boe_forex_codes.split("\n")[1:-1]])
        self.boe_code_map["ILS"] = self.boe_code_map["NIS"]
        super().__init__(source)

    @staticmethod
    def _scale_pence(df, pence, invert):
        """Rescale a GBP-based rate to pence; no-op when `pence` is False.

        USD per penny is 1/100 of USD per pound (GBCUSD), while pence per USD
        is 100x pounds per USD (USDGBC).
        """
        if not pence:
            return df
        return df * 100 if invert else df / 100

    def fetch(self, symbol, conf):
        assert len(symbol.name) == 6
        _from = symbol.name[:3]
        _to = symbol.name[3:]
        if _to != "USD" and _from != "USD":
            raise Exception("Can only convert to/from USD")
        invert = _from == "USD"
        curr = _to if invert else _from

        pence = curr == "GBC"
        if pence:
            curr = "GBP"

        if curr in self.fred_code_map:
            code = self.fred_code_map[curr]
            df = quandl.get("FRED/" + code)
            if code.endswith("US") != invert: # some of the FRED currencies are inverted vs the US dollar, argh..
                df = df.apply(lambda x: 1.0/x)
            return self._scale_pence(df, pence, invert)

        if curr in self.boe_code_map:
            code = self.boe_code_map[curr]
            df = quandl.get("BOE/" + code)
            if not invert: # not sure if some of BEO currencies are NOT inverted vs USD, checked a few and they weren't
                df = df.apply(lambda x: 1.0/x)
            return self._scale_pence(df, pence, invert)

        raise Exception("Currency pair is not supported: " + symbol.name)

    def process(self, symbol, df, conf):
        # quandl returns a single value column for these datasets; take the first
        return df.iloc[:, 0]
      
# https://github.com/ranaroussi/fix-yahoo-finance
class YahooDataSource(DataSource):
    """Daily prices, dividends and splits from Yahoo via pandas_datareader.

    `pdr.get_data_yahoo` is the fix_yahoo_finance replacement installed by
    `yf.pdr_override()` in the imports cell.
    """
    def fetch(self, symbol, conf):
        # actions=True also requests the "Dividends" / "Stock Splits" columns that process() reads
        return pdr.get_data_yahoo(symbol.name, progress=False, actions=True)

    def process(self, symbol, df, conf):
        """Select the series for conf.mode ("TR", "PR" or "divs"); other modes raise."""
        if conf.mode == "TR":
            # "Adj Close" is split- and dividend-adjusted, so both flags must be on
            assert conf.splitAdj and conf.divAdj
            return df["Adj Close"]
        elif conf.mode == "PR":
            # Yahoo "Close" data is split adjusted.
            # We find the unadjusted data using the splits data: cumulative product
            # of split values from the newest row backwards, shifted one row so a
            # split only affects rows older than the split date.
            # NOTE(review): correctness depends on how fix_yahoo_finance encodes
            # "Stock Splits" (ratio vs. inverse ratio, 0 vs. NaN on non-split days) — confirm.
            splitMul = df["Stock Splits"][::-1].cumprod().shift().fillna(method="bfill")
            return df["Close"] / splitMul
        elif conf.mode == "divs":
            return df["Dividends"]
        else:
            raise Exception("Unsupported mode [" + conf.mode + "] for YahooDataSource")

class QuandlDataSource(DataSource):
    """Any Quandl dataset; `symbol.name` is the dataset code ("DATABASE/DATASET")."""

    def fetch(self, symbol, conf):
        return quandl.get(symbol.name)

    def process(self, symbol, df, conf):
        # prefer an explicit "Close" column, otherwise use the first column
        column = df["Close"] if "Close" in df.columns else df.iloc[:, 0]
        return column

    
class GoogleDataSource(DataSource):
    """Daily prices from pandas_datareader's 'google' reader.

    NOTE(review): the 'google' reader was deprecated/removed in later
    pandas_datareader releases — confirm it still works with the pinned version.
    """

    def fetch(self, symbol, conf):
        reader = pandas_datareader.data.DataReader
        return reader(symbol.name, 'google')

    def process(self, symbol, df, conf):
        closes = df["Close"]
        return closes
    
# Alpha Vantage API key. Set the AV_API_KEY environment variable instead of
# relying on the value hardcoded here (keys should not live in the notebook).
AV_API_KEY = os.environ.get("AV_API_KEY", 'BB18')
class AlphaVantageDataSource(DataSource):
    """Daily adjusted equity prices from Alpha Vantage (full history)."""

    def adjustSplits(self, price, splits):
        """Multiply `price` by the running product (in series order) of split coefficients."""
        return price * splits.cumprod()

    # AV sometimes have duplicate split multiplers, we only use the last one
    def fixAVSplits(self, df):
        """Collapse runs of consecutive non-1.0 split coefficients to a single one.

        Walking from the newest row to the oldest, the first non-1.0 value of
        each run is kept and the older values in that run are reset to 1.0.
        Returns a date-sorted copy of `df`.
        """
        df = df.sort_index()
        coefficients = df["8. split coefficient"].copy()
        run = 0
        for stamp, ratio in reversed(list(coefficients.items())):
            if ratio == 1.0:
                run = 0
                continue
            run += 1
            if run > 1:
                coefficients[stamp] = 1.0
        df["8. split coefficient"] = coefficients
        return df

    def fetch(self, symbol, conf):
        client = TimeSeries(key=AV_API_KEY, output_format='pandas')
        frame, _meta = client.get_daily_adjusted(symbol.name, outputsize="full")
        frame.index = pd.to_datetime(frame.index, format="%Y-%m-%d")
        return self.fixAVSplits(frame)

    def process(self, symbol, df, conf):
        """Select the series for conf.mode ("TR", "PR" or "divs"); other modes raise."""
        if conf.mode == "TR":
            return df["5. adjusted close"]
        if conf.mode == "PR":
            return self.adjustSplits(df["4. close"], df['8. split coefficient'])
        if conf.mode == "divs":
            return df["7. dividend amount"]
        raise Exception("Unsupported mode [" + conf.mode + "] for AlphaVantageDataSource")
        
class AlphaVantageCryptoDataSource(DataSource):
    """Daily crypto-currency prices quoted in USD from Alpha Vantage."""

    def fetch(self, symbol, conf):
        client = CryptoCurrencies(key=AV_API_KEY, output_format='pandas')
        frame, _meta = client.get_digital_currency_daily(symbol=symbol.name, market='USD')
        frame.index = pd.to_datetime(frame.index, format="%Y-%m-%d")
        return frame

    def process(self, symbol, df, conf):
        usd_close = df['4a. close (USD)']
        return usd_close

class CryptoCompareDataSource(DataSource):
    """Daily crypto prices in USD from CryptoCompare's `histoday` endpoint (CCCAGG)."""

    # seconds to wait for the HTTP response before giving up
    timeout = 30

    def fetch(self, symbol, conf):
        """Return a frame indexed by day, or None when the API returns no rows.

        Raises requests.HTTPError for non-2xx responses.
        """
        url = "https://min-api.cryptocompare.com/data/histoday?fsym=__sym__&tsym=USD&limit=600000&aggregate=1&e=CCCAGG"
        response = requests.get(url.replace("__sym__", symbol.name), timeout=self.timeout)
        response.raise_for_status()
        d = json.loads(response.text)
        # error payloads may lack "Data"; treat that like an empty result
        df = pd.DataFrame(d.get("Data") or [])
        if len(df) == 0:
            return None
        df["time"] = pd.to_datetime(df.time, unit="s")
        return df.set_index("time")

    def process(self, symbol, df, conf):
        return df.close

# NOTE: data is SPLIT adjusted, but has no dividends and is NOT DIVIDEND adjusted
# NOTE: it has data all the way to the start, but returned result is capped in length for ~20 years
#       and results are trimmed from the END, not from the start. TBD to handle this properly.
#       for now we start at 1.1.2000
class InvestingComDataSource(DataSource):
    """Daily closing prices scraped from investing.com's historical-data AJAX endpoint.

    Flow:
    1. Search for the symbol (getUrl).
    2. Scrape its pairId/smlId from the historical-data page (getCodes).
    3. POST for the HTML table of rows (getHtml).
    4. Parse the rows (fetch).
    """

    # seconds to wait for each HTTP response
    timeout = 30

    def _headers(self, referer, accept='application/json, text/javascript, */*; q=0.01'):
        """Browser-like request headers shared by the investing.com requests below."""
        return {
            'Origin': 'https://www.investing.com',
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'en-US,en;q=0.9,he;q=0.8',
            'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/66.0.3359.117 Safari/537.36',
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': accept,
            'Referer': referer,
            'X-Requested-With': 'XMLHttpRequest',
            'Connection': 'keep-alive',
        }

    def getUrl(self, symbol):
        """Return the site-relative link of the first search result for `symbol.name`.

        Raises:
            Exception: when the search returns no results.
        """
        name = symbol.name
        data = {
            'search_text': name,
            'term': name,
            'country_id': '0',
            'tab_id': 'All'
        }
        headers = self._headers('https://www.investing.com/search?q=' + name)
        r = requests.post("https://www.investing.com/search/service/search",
                          data=data, headers=headers, timeout=self.timeout)
        hits = json.loads(r.text).get("All") or []
        if not hits:
            raise Exception("Symbol not found in source: " + str(symbol))
        return hits[0]["link"]

    def getCodes(self, url):
        """Scrape (pairId, smlId) for getHtml from the symbol's historical-data page.

        Raises:
            Exception: when either id is not present in the page.
        """
        url = "https://www.investing.com" + url + "-historical-data"
        r = requests.get(url, headers=self._headers('https://www.investing.com/'), timeout=self.timeout)
        text = r.text

        sml = re.search(r"smlId:\s+(\d+)", text)
        pair = re.search(r"pairId:\s+(\d+)", text)
        if sml is None or pair is None:
            raise Exception("Could not find pairId/smlId on page: " + url)
        return pair.group(1), sml.group(1)

    def getHtml(self, pairId, smlId):
        """POST for the HTML table of daily rows from 01/01/2000 onwards, newest first."""
        data = [
            'curr_id=' + pairId,
            'smlID=' + smlId,
            'header=',
            'st_date=01%2F01%2F2000',
            'end_date=01%2F01%2F2100',
            'interval_sec=Daily',
            'sort_col=date',
            'sort_ord=DESC',
            'action=historical_data'
        ]
        data = "&".join(data)
        headers = self._headers('https://www.investing.com/', accept='text/plain, */*; q=0.01')
        r = requests.post("https://www.investing.com/instruments/HistoricalDataAjax",
                          data=data, headers=headers, timeout=self.timeout)
        return r.text

    @staticmethod
    def _soup(html):
        """Parse `html` with lxml when installed, else with the stdlib html.parser.

        The install cell does not install lxml, so the fallback is needed there.
        """
        from bs4 import FeatureNotFound
        try:
            return BeautifulSoup(html, "lxml")
        except FeatureNotFound:
            return BeautifulSoup(html, "html.parser")

    def fetch(self, symbol, conf):
        """Return a frame with a "price" column indexed by (UTC) date."""
        pairId, smlId = self.getCodes(self.getUrl(symbol))
        parsed_html = self._soup(self.getHtml(pairId, smlId))

        rows = []
        for tr in parsed_html.find_all("tr")[1:]: # skip header
            cells = [td.get("data-real-value") for td in tr.find_all("td")]
            if len(cells) == 0 or cells[0] is None:
                continue
            # cell 0 holds a unix timestamp, cell 1 the close (other cells: open/high/low/volume)
            date = datetime.datetime.utcfromtimestamp(int(cells[0]))
            close = float(cells[1].replace(",", ""))
            rows.append((date, close))

        # build the frame once (typed columns) instead of growing it row by row
        return pd.DataFrame(rows, columns=["date", "price"]).set_index("date")

    def process(self, symbol, df, conf):
        return df['price']

import time    
class JustEtfDataSource(DataSource):
    """ETF prices and dividends scraped from justetf.com's chart widget.

    `fetch` treats `symbol.name` as the ETF's ISIN. `getIsin` can look up an
    ISIN from a ticker but is not called by `fetch`. The chart payload is
    JavaScript, which `parseRawText` parses by string splitting.
    """

    # seconds to wait for each HTTP response
    timeout = 30

    def parseDate(self, s):
        """Parse a `Date.UTC(y, m, d)` fragment; the month is 0-based (JS), hence `+ 1`."""
        s = s.strip(" {x:Date.UTC()")
        p = [int(x) for x in s.split(",")]
        return datetime.datetime(p[0], p[1] + 1, p[2])

    def parseDividends(self, x):
        """Extract (date, amount) dividend points from the dividend-series payload."""
        body = re.split(r"data: \[", x)[1]
        body = re.split(r"\]\^,", body)[0]
        data = []
        for entry in re.split(r"\},\{", body):
            p = re.split(r", events: \{click: function\(\) \{  \}\}, title: 'D', text: 'Dividend ", entry)
            when = self.parseDate(p[0])
            # remainder reads "<currency> <amount>" plus trailing quote/id characters
            p = p[1].strip("',id: }").split(" ")
            data.append((when, float(p[1])))
        return pd.DataFrame(data, columns=['dt', 'divs']).set_index("dt")

    def parsePrice(self, s):
        """Extract (date, price) points from the price-series payload."""
        data = []
        line = s
        t = "data: ["
        line = line[line.find(t) + len(t):]
        t = "^]^"
        line = line[:line.find(t)]
        for p in line.split("^"):
            p = p.strip("[],")
            p = p.split(")")
            value = float(p[1].replace(",", ""))
            dateStr = p[0].split("(")[1]
            d = [int(x) for x in dateStr.split(",")]
            data.append((datetime.datetime(d[0], d[1] + 1, d[2]), value))
        return pd.DataFrame(data, columns=['dt', 'price']).set_index("dt")

    def parseRawText(self, s):
        """Split the chart response into price and dividend series; return one frame."""
        x = re.split("addSeries", s)
        df = self.parsePrice(x[1])
        divs = self.parseDividends(x[2])
        df["divs"] = divs["divs"]
        return df

    def getIsin(self, symbol):
        """Look up the ISIN of the ETF whose ticker equals `symbol.name`.

        Returns (isin, session) so the same HTTP session can be reused.

        Raises:
            Exception: when no search result has a matching ticker.
        """
        symbolName = symbol.name
        data = {
            'draw': '1',
            'start': '0',
            'length': '25',
            'search[regex]': 'false',
            'lang': 'en',
            'country': 'GB',
            'universeType': 'private',
            'etfsParams': 'query=' + symbolName,
            'groupField': 'index',
        }
        headers = {
            'Accept-Encoding': 'gzip, deflate, br',
        }
        session = requests.Session()

        # NOTE(review): presumably this visit only primes session cookies (the ISIN is fixed)
        session.get("https://www.justetf.com/en/etf-profile.html?tab=chart&isin=IE00B5L65R35",
                    headers=headers, timeout=self.timeout)

        r = session.post("https://www.justetf.com/servlet/etfs-table",
                         data=data, headers=headers, timeout=self.timeout)
        res = json.loads(r.text)
        for d in res["data"]:
            if d["ticker"] == symbolName:
                return (d["isin"], session)
        raise Exception("Symbol not found in source: " + str(symbol))

    def getData(self, isin, session, conf, raw=False):
        """Replay the chart-page requests for `isin` and return the parsed frame.

        The requests are order-dependent (server-side wicket state), so they
        are issued in this fixed sequence. With conf.mode == "PR" the
        "includePayment" option is toggled first. With `raw=True` the final
        response text is returned unparsed.
        """
        if not session:
            session = requests.Session()

        headers = {
            'Accept-Encoding': 'gzip, deflate, br',
        }

        url3 = "https://www.justetf.com/uk/etf-profile.html?groupField=index&from=search&isin=" + isin + "&tab=chart"
        session.get(url3, headers=headers, timeout=self.timeout)
        session.get("https://www.justetf.com/sw.js", headers=headers, timeout=self.timeout)

        # NOTE(review): the referer hard-codes one ISIN rather than `isin` — confirm the site ignores it
        headers = {
            'accept-encoding': 'gzip, deflate, br',
            'accept-language': 'en-US,en;q=0.9,he;q=0.8',
            'wicket-focusedelementid': 'id1b',
            'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.62 Safari/537.36',
            'accept': 'text/xml',
            'referer': 'https://www.justetf.com/uk/etf-profile.html?groupField=index&from=search&isin=IE00B5L65R35&tab=chart',
            'authority': 'www.justetf.com',
            'wicket-ajax': 'true'
        }

        if conf.mode == "PR":
            headers["wicket-focusedelementid"] = "includePayment"
            url = "https://www.justetf.com/uk/?wicket:interface=:2:tabs:panel:chart:optionsPanel:selectContainer:includePaymentContainer:includePayment::IBehaviorListener:0:1&random=0.8852086768595453"
            session.post(url, headers=headers, timeout=self.timeout)
            headers["wicket-focusedelementid"] = "id1b"

        url = "https://www.justetf.com/en/?wicket:interface=:0:tabs:panel:chart:dates:ptl_3y::IBehaviorListener:0:1&random=0.2525050377785838"
        session.get(url, headers=headers, timeout=self.timeout)

        # PRICE (instead of percent change)
        data = { 'tabs:panel:chart:optionsPanel:selectContainer:valueType': 'market_value' }
        url = "https://www.justetf.com/en/?wicket:interface=:0:tabs:panel:chart:optionsPanel:selectContainer:valueType::IBehaviorListener:0:1&random=0.7560635418741075"
        session.post(url, headers=headers, data=data, timeout=self.timeout)

        # full history
        url = "https://www.justetf.com/en/?wicket:interface=:0:tabs:panel:chart:dates:ptl_max::IBehaviorListener:0:1&random=0.2525050377785838"
        text = session.get(url, headers=headers, timeout=self.timeout).text
        if raw:
            return text

        return self.parseRawText(text)

    def fetch(self, symbol, conf):
        # The memory/disk cache key ignores conf.mode, and process() only serves
        # "TR" and "divs". Always download the default (TR) series; otherwise a
        # "PR" request would cache price-only data that later TR requests reuse.
        import copy
        tr_conf = copy.copy(conf)
        tr_conf.mode = "TR"
        return self.getData(symbol.name, None, tr_conf)

    def process(self, symbol, df, conf):
        """Select "price" for "TR" or "divs" for "divs"; any other mode (incl. "PR") raises."""
        if conf.mode == "TR":
            return df["price"]
        if conf.mode == "divs":
            return df["divs"]
        raise Exception("Unsupported mode [" + conf.mode + "] for JustEtfDataSource")
    
#x = JustEtfDataSource("XXX")
#isin, session = x.getIsin(Symbol("ERNS"))
#t = x.getData(isin, session)

#conf = lambda x: x
#conf.mode = "TR"
#t = x.getData("IE00B5L65R35", None, conf, True)

class BloombergDataSource(DataSource):
    """Up to 5 years of prices from Bloomberg's bulk time-series JSON endpoint.

    Write ":" in Bloomberg tickers as ";" (e.g. "SPX;IND@B"); it is converted
    back before the request. This is presumably because ":" separates weights
    in portfolio definitions ("SYM:weight|...").
    """

    # seconds to wait for the HTTP response before giving up
    timeout = 30

    def fetch(self, symbol, conf):
        """Return a frame indexed by date, or None when no price rows come back.

        Raises requests.HTTPError for non-2xx responses.
        """
        url = "https://www.bloomberg.com/markets/api/bulk-time-series/price/__sym__?timeFrame=5_YEAR"
        sym = symbol.name.replace(";", ":")
        response = requests.get(url.replace("__sym__", sym), timeout=self.timeout)
        response.raise_for_status()
        d = json.loads(response.text)
        if not d:
            return None
        df = pd.DataFrame(d[0]["price"])
        if len(df) == 0:
            return None
        df["date"] = pd.to_datetime(df.date, format="%Y-%m-%d")
        return df.set_index("date")

    def process(self, symbol, df, conf):
        return df.value




In [None]:
# fetching data

if not "Wrapper" in locals():
    class Wrapper(object):
        """Thin proxy around a pd.Series that keeps results wrapped and named.

        Attribute access is forwarded to the wrapped series `s`. Forwarded
        method calls that return a pd.Series are re-wrapped. Arithmetic builds
        names such as "A / B" via `wrap`. Comparisons return plain boolean
        Series, so `w[w > 0]` works. The guard above means re-running this
        cell does not create a second Wrapper class, which would make existing
        instances fail isinstance checks.
        """

        def __init__(self, s):
            # bypass our own __setattr__, which forwards writes to the series
            object.__setattr__(self, "s", s)

        def __getattr__(self, name):
            # Only called for names not found normally; forward them to the series.
            if name == "s":
                # `s` is missing only if __init__ never ran (e.g. Wrapper.__new__,
                # copy/pickle); without this guard `self.s` would recurse forever.
                raise AttributeError(name)
            attr = self.s.__getattribute__(name)

            if hasattr(attr, '__call__'):
                def newfunc(*args, **kwargs):
                    result = attr(*args, **kwargs)
                    if type(result) is pd.Series:
                        result = Wrapper(result)
                    return result
                return newfunc

            if type(attr) is pd.Series:
                attr = Wrapper(attr)
            return attr

        def __setattr__(self, name, value):
            # attribute writes (e.g. `.name = ...`) go to the wrapped series
            self.s.__setattr__(name, value)

        def __getitem__(self, item):
            # allow another Wrapper (e.g. a boolean mask) as the key
            if type(item) is Wrapper:
                item = item.s
            return wrap(self.s.__getitem__(item), self.s.name)

        @staticmethod
        def _unwrap(value):
            """Return the wrapped series for a Wrapper, `value` itself otherwise."""
            return value.s if type(value) is Wrapper else value

        def doop(self, other, opname, opLambda, right=False):
            """Apply binary `opLambda` to the series and `other`; return a named Wrapper.

            right=True computes `other <op> self` (for reflected operators). When
            both operands are Wrappers, the result is named "<left> <opname> <right>".
            """
            operand = Wrapper._unwrap(other)
            if right:
                series = opLambda(operand, self.s)
            else:
                series = opLambda(self.s, operand)
            name = self.name
            if type(other) is Wrapper:
                if right:
                    name = other.s.name + " " + opname + " " + self.s.name
                else:
                    name = self.s.name + " " + opname + " " + other.s.name
            return wrap(series, name)

        def __truediv__(self, other):
            return Wrapper.doop(self, other, "/", lambda x, y: x / y)
        def __rtruediv__(self, other):
            return Wrapper.doop(self, other, "/", lambda x, y: x / y, right=True)

        def __add__(self, other):
            return Wrapper.doop(self, other, "+", lambda x, y: x + y)

        def __sub__(self, other):
            return Wrapper.doop(self, other, "-", lambda x, y: x - y)

        def __mul__(self, other):
            return Wrapper.doop(self, other, "*", lambda x, y: x * y)
        def __rmul__(self, other):
            return Wrapper.doop(self, other, "*", lambda x, y: x * y, right=True)

        # Comparisons return the plain boolean pd.Series for use as a mask.
        def __gt__(self, other):
            return self.s > Wrapper._unwrap(other)
        def __ge__(self, other):
            return self.s >= Wrapper._unwrap(other)
        def __lt__(self, other):
            return self.s < Wrapper._unwrap(other)
        def __le__(self, other):
            return self.s <= Wrapper._unwrap(other)

def wrap(s, name=""):
    """Name `s`, wrapping a bare pd.Series in a Wrapper.

    `name` falls back to `s.name`; raises Exception("no name") if neither is
    set. Naming writes through to the underlying series, so the caller's
    series is renamed too. Values that are neither Series nor Wrapper are
    returned unchanged.
    """
    label = name if name else s.name
    if not label:
        raise Exception("no name")
    if isinstance(s, pd.Series):
        wrapped = Wrapper(s)
        wrapped.name = label
        return wrapped
    if isinstance(s, Wrapper):
        s.name = label
    return s

name = wrap # syn-sugar

    
# Registry of data sources, keyed by the code that follows "@" in a symbol
# name (e.g. "SPY@Y"); getFrom looks sources up here. Each instance is
# constructed with its own key as its source code.
data_sources = {code: cls(code) for code, cls in [
    ("B", BloombergDataSource),
    ("JT", JustEtfDataSource),
    ("Y", YahooDataSource),
    ("IC", InvestingComDataSource),
    ("Q", QuandlDataSource),
    ("AV", AlphaVantageDataSource),
    ("CC", CryptoCompareDataSource),
    ("CCAV", AlphaVantageCryptoDataSource),
    ("CUR", ForexDataSource),
    ("G", GoogleDataSource),
]}

def getFrom(symbol, conf):
    """Fetch `symbol` (a Symbol) according to `conf` (a GetConf).

    Six-letter names with USD as either half and no explicit source (e.g.
    "EURUSD") are treated as forex pairs and delegated to getForex. Otherwise
    the source is symbol.source, then conf.source, then "AV".

    If that source fails, neither source was given explicitly, and
    conf.secondary is set to a *different* source, the secondary is tried.

    Raises:
        Exception: for an unsupported source, or the fetch error itself.
    """
    # special handling for forex
    if len(symbol.name) == 6 and not symbol.source:
        base, quote = symbol.name[:3], symbol.name[3:]
        if base == "USD" or quote == "USD":
            return wrap(getForex(base, quote), symbol.name)

    source = symbol.source or conf.source or "AV"
    if not source in data_sources:
        raise Exception("Unsupported source: " + source)

    explicit = bool(symbol.source or conf.source)
    # no usable fallback: no secondary, the source was explicit, or the
    # secondary is the source that is about to fail anyway
    if not conf.secondary or explicit or conf.secondary == source:
        return data_sources[source].get(symbol, conf)

    try:
        return data_sources[source].get(symbol, conf)
    except Exception:
        print("Failed to fetch {0} from {1}, trying from {2} .. ".format(symbol, source, conf.secondary), end="")
        res = data_sources[conf.secondary].get(symbol, conf)
        print("DONE")
        return res

def format_filename(s):
    """Make `s` safe as a file name.

    Keeps only ASCII letters, digits and the characters "-_.() ", then
    replaces spaces with "_".
    """
    allowed = set("-_.() " + string.ascii_letters + string.digits)
    kept = "".join(ch for ch in s if ch in allowed)
    return kept.replace(" ", "_")
    
def cache_file(symbol, source):
    """Return the cache path "symbols/<source>/<sanitised symbol.name>", creating its directory."""
    filepath = os.path.join("symbols", source, format_filename(symbol.name))
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    return filepath

def cache_get(symbol, source):
    """Load the cached frame for `symbol` from its CSV, or return None if there is none.

    The "date" column becomes a DatetimeIndex parsed as "%Y-%m-%d".
    """
    filepath = cache_file(symbol, source)
    if not os.path.exists(filepath):
        return None
    # `squeeze` is omitted: False was its default, and pandas 2.0 removed the argument
    res = pd.read_csv(filepath, index_col="date")
    res.index = pd.to_datetime(res.index, format="%Y-%m-%d")
    return res

def cache_set(symbol, source, s):
    """Write `s` to its cache CSV, with the index written as "%Y-%m-%d" dates under a "date" header."""
    target = cache_file(symbol, source)
    s.to_csv(target, date_format="%Y-%m-%d", index_label="date")


def get_port(d, name, getArgs):
    """Build a composite portfolio series from percentage weights.

    d: dict {symbol: weight in percent}, or its "SYM:w|SYM:w" string form.
    name: display name; when a dict is passed, a "SYM:w|..." name is generated.
    getArgs: keyword arguments forwarded to `get` for every component.

    Each component's logret is scaled by weight/100. The results are summed per
    date, over the dates that all components share, and passed to i_logret
    (presumably the inverse of logret; both are defined elsewhere).

    Raises:
        Exception: for an unparsable string or a non-str/dict definition.
    """
    if isinstance(d, str):
        res = parse_portfolio_def(d)
        if not res:
            raise Exception("Invalid portfolio definition: " + d)
        d = res
    if not isinstance(d, dict):
        # str() is required: concatenating a type object raised TypeError instead
        raise Exception("Portfolio definition must be str or dict, was: " + str(type(d)))
    if isinstance(name, dict):
        name = "|".join([getName(k) + ":" + str(v) for k, v in d.items()])
    df = pd.DataFrame(logret(get(k, **getArgs).s) * v / 100 for k, v in d.items()).T.dropna()
    res = Wrapper(i_logret(df.sum(axis=1)))
    res.name = name
    return res

def parse_portfolio_def(s):
    """Parse a "SYM:weight|SYM:weight" portfolio definition into {symbol: weight}.

    A dict is returned as-is. Returns None when `s` is not a valid definition:
    non-str input, a part without exactly one ":", or a non-numeric weight
    (e.g. "SPX:IND"). This lets plain symbol names fall through to a normal
    lookup instead of raising ValueError.
    """
    if isinstance(s, dict):
        return s
    if not isinstance(s, str):
        return None
    d = {}
    for part in s.split("|"):
        pieces = part.split(":")
        if len(pieces) != 2:
            return None
        try:
            d[pieces[0]] = float(pieces[1])
        except ValueError:
            return None
    return d

def getNtr(s, getArgs, tax=0.25):
    """Net total return series for Symbol `s`: price return plus reinvested net dividends.

    getArgs: keyword arguments for `get`; its "mode" is overridden per request
        on a copy, so the caller's dict is never modified.
    tax: withholding rate applied to each dividend (default 0.25).

    Each net dividend is divided by the same day's price, compounded into a
    cumulative multiplier, and applied to the price series.
    """
    pr = get(s, **dict(getArgs, mode="PR"))
    divs = get(s, **dict(getArgs, mode="divs"))

    divs = divs * (1 - tax)
    divs = divs / pr
    divs = divs.fillna(0)   # days without a dividend contribute nothing
    r = (1 + divs.s).cumprod()
    ntr = (pr * r).dropna()
    return wrap(ntr, s.name)


def get(symbol, source=None, cache=True, splitAdj=True, divAdj=True, adj=None, mode="TR", secondary="Y", fillDays=True, despike=False, trim=False):
    """Fetch a series for `symbol` and return it wrapped.

    symbol may be:
      - a list: each element is fetched; the list is then optionally despiked and trimmed;
      - a (symbol, period) tuple: the period is ignored here;
      - a Wrapper or pd.Series: returned unchanged;
      - a portfolio definition ("A:60|B:40" or {"A": 60, "B": 40}): built via get_port;
      - anything toSymbol() accepts.
    mode: "NTR" is computed by getNtr; other modes (e.g. "TR", "PR", "divs", "raw") go to getFrom.
    adj=False disables both splitAdj and divAdj.
    fillDays: reindex to every calendar day and interpolate (not applied to "divs"/"raw").
    despike/trim: only used when `symbol` is a list.
    """
    getArgs = dict(source=source, cache=cache, splitAdj=splitAdj, divAdj=divAdj, adj=adj,
                   mode=mode, secondary=secondary, fillDays=fillDays, despike=despike, trim=trim)

    if isinstance(symbol, list):
        lst = [get(s, **getArgs) for s in symbol]
        if despike:
            # the `despike` parameter shadows the module-level despike() function
            lst = [globals()["despike"](s) for s in lst]
        if trim:
            lst = doTrim(lst)
        return lst

    # support for yield period tuples, e.g.: (SPY, 4)
    if isinstance(symbol, tuple) and len(symbol) == 2:
        symbol, _ = symbol

    if isinstance(symbol, Wrapper) or isinstance(symbol, pd.Series):
        return symbol
    if "ignoredAssets" in globals() and ignoredAssets and symbol in ignoredAssets:
        return wrap(pd.Series(), "<empty>")

    # special handling for composite portfolios
    port = parse_portfolio_def(symbol)
    if port:
        return get_port(port, symbol, getArgs)

    symbol = toSymbol(symbol)

    if mode == "NTR":
        return getNtr(symbol, getArgs)

    if adj == False:
        splitAdj = False
        divAdj = False

    s = getFrom(symbol, GetConf(splitAdj, divAdj, cache, mode, source, secondary))

    s = s[s > 0]  # clean up broken yahoo data, etc ..

    # an empty series has no first/last date to build a daily range from (s.index[0] would raise)
    if fillDays and mode != "divs" and mode != "raw" and len(s.index) > 0:
        s = s.reindex(pd.date_range(start=s.index[0], end=s.index[-1]))
        s = s.interpolate()

    return wrap(s, symbol.fullname)


In [None]:
#     def __getattribute__(self,name):
#         s = object.__getattribute__(self, "s")
#         if name == "s":
#             return s
        
#         attr = s.__getattribute__(name)
        
#         if hasattr(attr, '__call__'):
#             def newfunc(*args, **kwargs):
#                 result = attr(*args, **kwargs)
#                 if type(result) is pd.Series:
#                     result = Wrapper(result)
#                 return result
#             return newfunc
        
#         if type(attr) is pd.Series:
#             attr = Wrapper(attr)
#         return attr
    


In [None]:
# plotting

from plotly.graph_objs import *

def createVerticalLine(xval):
    """Plotly shape dict: a 1px red line at x=xval spanning the full plot height (paper y 0..1)."""
    stroke = {'width': 1, 'color': 'red'}
    return dict([
        ('type', 'line'),
        ('x0', xval),
        ('x1', xval),
        ('yref', 'paper'),
        ('y0', 0),
        ('y1', 1),
        ('opacity', 1),
        ('line', stroke),
    ])
    
def createHorizontalLine(yval):
    """Plotly shape dict: a 1px red line at y=yval spanning the full plot width (paper x 0..1)."""
    stroke = {'width': 1, 'color': 'red'}
    return dict([
        ('type', 'line'),
        ('xref', 'paper'),
        ('x0', 0),
        ('x1', 1),
        ('y0', yval),
        ('y1', yval),
        ('opacity', 1),
        ('line', stroke),
    ])
    
def plot(*arr, log=True, title=None, legend=True):
    """Render series and reference lines in a single plotly chart.

    Each positional argument may be:
      - a Wrapper or pd.Series: drawn as a line trace named after the series;
      - a datetime.datetime or np.datetime64: a red vertical line at that date;
      - a real number: a red horizontal line at that value.
    log: use a log y-axis; automatically disabled if any trace has a value <= 0.
    title: chart title.
    legend: show the legend box in the top-left corner.
    Raises Exception for any other argument type.
    """
    data = []
    shapes = []
    for val in arr:
        # series
        if isinstance(val, Wrapper):
            data.append(go.Scatter(x=val.index, y=val.s, name=val.name, text=val.name))
        elif isinstance(val, pd.Series):
            data.append(go.Scatter(x=val.index, y=val, name=val.name, text=val.name))
        # vertical date line
        elif isinstance(val, datetime.datetime):
            shapes.append(createVerticalLine(val))
        # vertical date line: datetime64[ns].astype(datetime.datetime) returns an int
        # (ns since epoch), not a datetime; pd.Timestamp handles every resolution
        elif isinstance(val, np.datetime64):
            shapes.append(createVerticalLine(pd.Timestamp(val).to_pydatetime()))
        # horizontal value line
        elif isinstance(val, numbers.Real):
            shapes.append(createHorizontalLine(val))
        else:
            raise Exception("unsupported value type: " + str(type(val)))

    # a log axis cannot display non-positive values
    for d in data:
        d = d.y
        if isinstance(d, Wrapper):
            d = d.s
        if np.any(np.asarray(d) <= 0):
            log = False

    mar = 30
    margin = gol.Margin(
        l=mar,
        r=mar,
        b=mar,
        t=mar,
        pad=0
    )

    if legend:
        # Translucent white legend (alpha ~0xBB/255). rgb() takes 3 channels, so the former
        # 'rgb(255,255,255,187)' dropped the alpha; rgba() with a 0..1 alpha is needed.
        legendArgs = dict(x=0, y=1, traceorder='normal',
            bgcolor='rgba(255,255,255,0.73)', bordercolor='#888888', borderwidth=1,
            font=dict(family='sans-serif', size=12, color='#000'),
        )
    else:
        legendArgs = {}
    yaxisScale = "log" if log else None
    layout = go.Layout(legend=legendArgs,
                       showlegend=legend,
                       margin=margin,
                       yaxis=dict(type=yaxisScale, autorange=True),
                       shapes=shapes,
                       title=title,
                       hovermode='closest')
    fig = go.Figure(data=data, layout=layout)
    py.iplot(fig)

# show a stacked area chart normalized to 100% of multiple time series
def plotly_area(df, title=None):
    """Plot the columns of `df` as a stacked area chart, each row normalized to sum to 100%.

    Columns are stacked by descending mean share. Hover text shows each column's own
    (non-cumulative) percentage.
    """
    tt = df.div(df.sum(axis=1), axis=0) * 100  # normalize each row to sum 100
    tt = tt.reindex(tt.mean().sort_values(ascending=False).index, axis=1)  # sort columns by mean value
    tt = tt.sort_index()
    tt2 = tt.cumsum(axis=1)  # cumulative sum gives the top edge of each stacked band
    data = []
    for col in tt2:
        s = tt2[col]
        trace = go.Scatter(
            name=col,
            # Index.to_datetime() is deprecated (removed in pandas 1.0); pd.to_datetime is equivalent
            x=pd.to_datetime(s.index),
            y=s.values,
            text=["{:.1f}%".format(v) for v in tt[col].values],  # use text as non-cumsum values
            hoverinfo='name+x+text',
            mode='lines',
            fill='tonexty'
        )
        data.append(trace)

    mar = 30
    margin = gol.Margin(l=mar, r=mar, b=mar, t=mar, pad=0)
    # same translucent legend as plot(); an 8-digit hex color is not accepted by plotly's validator
    legend = dict(x=0, y=1, traceorder='reversed',
        bgcolor='rgba(255,255,255,0.73)', bordercolor='#888888', borderwidth=1,
        font=dict(family='sans-serif', size=12, color='#000'),
    )
    layout = go.Layout(margin=margin, legend=legend, title=title,
        xaxis=dict(
            type='date',
        ),
        yaxis=dict(
            type='linear',
            range=[1, 100],
            dtick=20,
            ticksuffix='%'
        )
    )
    fig = go.Figure(data=data, layout=layout)
    py.iplot(fig, filename='stacked-area-plot')
       

In [None]:
# data processing

def _start(s):
    return s.index[0]

def _end(s):
    return s.index[-1]

# def getCommonDate(data, alldata=False):
#     if alldata:
#         l = [_start(s) for s in data if isinstance(s, Wrapper) or isinstance(s, pd.Series)]
#     else:
#         l = [_start(s) for s in data if isinstance(s, Wrapper)]
#     if not l:
#         return None
#     return max(l)

def getCommonDate(data, alldata=False, agg=max, get_fault=False):
    """Aggregate (max by default) the start dates of the series in `data`.

    Only Wrapper items are considered, plus pd.Series items when alldata is True.
    Returns None when nothing qualifies. With get_fault=True, returns (date, names), where names is a
    comma-separated string of the series whose start date equals the aggregated date.
    """
    def _qualifies(item):
        if alldata:
            return isinstance(item, Wrapper) or isinstance(item, pd.Series)
        return isinstance(item, Wrapper)

    series = [item for item in data if _qualifies(item)]
    if not series:
        return None
    starts = [_start(item) for item in series]
    chosen = agg(starts)
    if not get_fault:
        return chosen
    culprits = ", ".join(item.name for start, item in zip(starts, series) if start == chosen)
    return chosen, culprits

def doTrim(data, alldata=False, silent=False):
    """Cut every series in `data` so they all start at their latest common start date.

    Only Wrapper items are trimmed (plus pd.Series when alldata is True); other items pass through.
    Series that become empty after trimming are dropped. Unless silent, prints the trimmed range
    and the series responsible for each end of it.
    """
    res = getCommonDate(data, alldata=alldata, get_fault=not silent)
    # getCommonDate returns a bare None (not a tuple) when nothing qualifies, so check
    # before unpacking; the old code raised TypeError here in non-silent mode
    if res is None:
        if not silent:
            print("Unable to trim data")
        return data
    if silent:
        date = res
    else:
        date, max_fault = res
    newArr = []
    for s in data:
        if isinstance(s, Wrapper) or (alldata and isinstance(s, pd.Series)):
            s = s[date:]
            if s.shape[0] == 0:
                continue
        newArr.append(s)
    if not silent:
        # use the same alldata filter as the max-date lookup above
        min_date, min_fault = getCommonDate(data, alldata=alldata, agg=min, get_fault=True)
        print(f"trimmed data from {min_date:%Y-%m-%d} [{min_fault}] to {date:%Y-%m-%d} [{max_fault}]")
    return newArr

def trimBy(trimmed, by):
    """Slice each series in `trimmed` from the latest start date among the series in `by`.

    Returns [] when `by` is empty.
    """
    if len(by) == 0:
        return []
    starts = [ref.index[0] for ref in by]
    cutoff = max(starts)
    return [series[cutoff:] for series in trimmed]

def doAlign(data):
    """Rebase each series in `data` by its first value on or after the common start date.

    The common date comes from getCommonDate(data) (Wrapper items only); if there is none, `data` is
    returned unchanged. Wrapper and pd.Series items are divided by their base value; series with no
    data from that date on are dropped. Other items pass through.
    """
    date = getCommonDate(data)
    if date is None:
        return data
    newArr = []
    for s in data:
        if isinstance(s, Wrapper) or isinstance(s, pd.Series):
            # s / s[date] can fail for messy data where not all series share the same index,
            # so use the first value on or after the common date instead
            base = s[date:]
            # .shape works for both Wrapper and pd.Series; `.s` only exists on Wrapper
            # (plain Series, e.g. DataFrame columns passed to show(), used to raise here)
            if base.shape[0] == 0:
                continue
            s = s / base[0]
        newArr.append(s)
    return newArr

def doClean(data):
    """Return a new list where Wrapper items have NaNs dropped; other items are kept as-is."""
    cleaned = []
    for item in data:
        cleaned.append(item.dropna() if isinstance(item, Wrapper) else item)
    return cleaned

def try_parse_date(s, format):
    """Parse `s` with strptime `format`; return None instead of raising ValueError on mismatch."""
    try:
        parsed = datetime.datetime.strptime(s, format)
    except ValueError:
        parsed = None
    return parsed

def easy_try_parse_date(s):
    """Try day-first formats dd/mm/yyyy, dd.mm.yyyy and dd-mm-yyyy in that order; None if none match."""
    for fmt in ("%d/%m/%Y", "%d.%m.%Y", "%d-%m-%Y"):
        parsed = try_parse_date(s, fmt)
        if parsed:
            return parsed
    return None
    
def show(*data, trim=True, align=True, ta=True, cache=None, mode=None, source=None, silent=False, **plotArgs):
    """Fetch, clean, trim, align and plot the given items.

    Items (lists are flattened one level) may be:
      - DataFrames, which contribute each column as a series;
      - datetimes, datetime64 values or dd/mm/yyyy-style date strings, which become vertical lines;
      - numbers, which become horizontal lines;
      - anything else, which is fetched with get().
    mode/cache/source: forwarded to get() when given.
    ta=False disables both trim and align.
    silent: instead of plotting, return the processed non-numeric items.
    plotArgs: forwarded to plot().
    """
    getArgs = {}
    if mode is not None:
        getArgs["mode"] = mode
    if cache is not None:
        getArgs["cache"] = cache
    if source is not None:
        getArgs["source"] = source  # previously assigned `cache` by mistake

    # flatten one level of lists
    flat = []
    for x in data:
        if isinstance(x, list):
            flat += x
        else:
            flat.append(x)

    items = []
    for x in flat:
        parsed_date = easy_try_parse_date(x) if isinstance(x, str) else None  # parse only once
        if isinstance(x, pd.DataFrame):
            items += [x[c] for c in x]
        elif isinstance(x, datetime.datetime) or isinstance(x, np.datetime64):
            items.append(x)
        elif parsed_date:
            items.append(parsed_date)
        elif isinstance(x, numbers.Real):
            items.append(x)
        else:
            items.append(get(x, **getArgs))

    data = doClean(items)
    if not ta:
        trim = False
        align = False
    if trim:
        data = doTrim(data)
    if align:
        data = doAlign(data)

    if not silent:
        plot(*data, **plotArgs)
    else:
        return [d for d in data if not isinstance(d, numbers.Real)]


    
def showRiskReturn(lst, setlim=True, lines=False, color=None, annotations=None, ret_func=None, risk_func=None):
    """Plot each asset in `lst` as a (risk, return) point with matplotlib.

    ret_func / risk_func: callables mapping a series to a number. They default to cagr (y axis) and
        ulcer (x axis), and their __name__ labels the axes.
    setlim: fit the axis limits to the data, always including 0.
    lines: connect the points in order instead of scattering them.
    annotations: point labels. None means use each series' .name; any falsy value (e.g. False)
        means no labels.
    Does nothing for an empty `lst`.
    """
    if len(lst) == 0:
        return
    ret_func = cagr if ret_func is None else ret_func
    risk_func = ulcer if risk_func is None else risk_func
    series = [get(item) for item in lst]
    if annotations is None:
        first_attrs = dir(series[0])
        if "name" in first_attrs or "s" in first_attrs:
            annotations = [item.name for item in series]
    rets = [ret_func(item) for item in series]
    risks = [risk_func(item) for item in series]
    if lines:
        plt.plot(risks, rets, marker="o", color=color)
    else:
        plt.scatter(risks, rets, color=color)
    if setlim:
        plt.xlim(min(0, min(risks)), max(risks) + 1)
        plt.ylim(min(0, min(rets) - 1), max(rets) + 1)
    plt.axhline(0, color='gray', linewidth=1)
    plt.axvline(0, color='gray', linewidth=1)
    plt.xlabel(risk_func.__name__, fontsize=20)
    plt.ylabel(ret_func.__name__, fontsize=20)
    if annotations:
        for pos, label in enumerate(annotations):
            plt.annotate(label, (risks[pos], rets[pos]), fontsize=14)

def show_risk_return_ntr_mode(lst, ret_func=None):
    """Risk/return plot of the NTR versions of `lst`, with a line from each TR point to its NTR point."""
    def fetch(mode):
        return get(lst, mode=mode, despike=True, trim=True)

    tr_series, ntr_series = fetch("TR"), fetch("NTR")
    showRiskReturn(ntr_series, ret_func=ret_func)
    for tr_item, ntr_item in zip(tr_series, ntr_series):
        showRiskReturn([tr_item, ntr_item], setlim=False, lines=True, ret_func=ret_func, annotations=False)

def mix(s1, s2, n=10, **getArgs):
    """Return n+1 two-asset portfolios, sweeping s1's weight from 0% to 100% (s2 gets the rest)."""
    step = 100 / n
    return [get({s1: k * step, s2: (100 - k * step)}, **getArgs) for k in range(n + 1)]
        
def _rolling_wrap(s, n, label, agg):
    """Apply rolling(n).<agg>() to `s` and wrap the result as "<label>(<name>, <n>)"."""
    n = int(n)
    window = s.rolling(n)
    return wrap(getattr(window, agg)(), "{}({}, {})".format(label, s.name, n))

def ma(s, n):
    """n-period moving average."""
    return _rolling_wrap(s, n, "ma", "mean")

def mm(s, n):
    """n-period moving median."""
    return _rolling_wrap(s, n, "mm", "median")

def mmax(s, n):
    """n-period moving maximum."""
    return _rolling_wrap(s, n, "mmax", "max")

def mmin(s, n):
    """n-period moving minimum."""
    return _rolling_wrap(s, n, "mmin", "min")

def cagr(s):
    """Compound annual growth rate of `s`, in percent.

    Uses the first and last values and the calendar span between the first and last index dates
    (365-day years). Returns NaN for an empty series or when the span is not positive.
    """
    # an empty series has no first/last date (s.index[-1] would raise IndexError)
    if len(s.index) == 0:
        return np.nan
    days = (s.index[-1] - s.index[0]).days
    if days <= 0:
        return np.nan
    years = days / 365
    val = s[-1] / s[0]
    return (math.pow(val, 1 / years) - 1) * 100

def ulcer(x):
    """Ulcer index: root-mean-square of percentage drawdowns from the running peak."""
    running_peak = np.maximum.accumulate(x)
    drawdown_pct = (x / running_peak - 1) * 100
    mean_square = np.sum(drawdown_pct * drawdown_pct) / x.shape[0]
    return math.sqrt(mean_square)

def stdmret(s):
    """Annualized std of period-over-period returns of `s`, in percent (scaled by sqrt(12)).

    NOTE(review): the sqrt(12) scaling is only correct if `s` is sampled monthly (e.g. bom(s)).
    """
    annualize = math.sqrt(12)
    return ret(s).std() * annualize * 100

def _first_of_period(s, unit):
    """Values of `s` at the first calendar day of each `unit` ("M" or "Y") period it spans.

    Periods whose first day is not in the index (or holds NaN) are dropped.
    """
    periods = np.unique(s.index.values.astype("datetime64[{}]".format(unit)))  # distinct periods
    starts = pd.DatetimeIndex(periods.astype("datetime64[ns]"))
    # reindex yields NaN for missing labels; s[labels] with missing labels raises KeyError in
    # modern pandas (and was deprecated as early as 0.21)
    return s.reindex(starts).dropna()

def bom(s):
    """Beginning-of-month values of `s`."""
    return _first_of_period(s, "M")

def boy(s):
    """Beginning-of-year values of `s`."""
    return _first_of_period(s, "Y")
    
def ret(s):
    """Simple period-over-period returns (first element NaN)."""
    changes = s.pct_change()
    return changes

def logret(s):
    """Period-over-period log returns of `s`, named "logret(<name>)"; the first element is NaN."""
    res = np.log(s) - np.log(s.shift(1))
    # format() instead of concatenation: an unnamed series (name None) previously raised TypeError
    res.name = "logret({})".format(s.name)
    return res

def i_logret(s):
    """Inverse of logret: compound a log-return series into a growth series (exp of cumulative sum)."""
    cumulative = np.cumsum(s)
    return np.exp(cumulative)

def lrret(regressors, target, sum1=False):
    """Fit target log returns as a linear combination (no intercept) of regressor log returns.

    Prints the regressor names, the coefficients and the r^2 score, plots the target against the
    fitted series (via show), and returns the fitted series as a Wrapper named "<target> fit".

    sum1: attempt to constrain the coefficients to sum to 1 by expressing every regressor (and the
        target) relative to the last regressor. Marked below as unfinished / not working.
    """
    regressors = [get(x) for x in regressors]
    target = get(target)
    # log returns of all regressors, with the target's log returns last
    all = [logret(x).s for x in (regressors + [target])]
    
    # based on: https://stats.stackexchange.com/questions/21565/how-do-i-fit-a-constrained-regression-in-r-so-that-coefficients-total-1?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
    # NOTE: not finished, not working
    if sum1:
        allOrig = all
        last = all[-2]  # the last regressor
        # subtract the last regressor from every other regressor and from the target
        all = [r - last for r in (all[:-2] + [all[-1]])]
        
    # one column per series, rows with any missing value dropped; last column is the target
    data = pd.DataFrame(all).T
    data = data.dropna()
    y = data.iloc[:, -1]
    X = data.iloc[:, :-1]

    regr = linear_model.LinearRegression(fit_intercept=False)
    regr.fit(X, y)
    
    if sum1:
        # the last regressor's weight is implied by the sum-to-1 constraint
        weights = np.append(regr.coef_, 1-np.sum(regr.coef_))
        
        # refit on the original (untransformed) series only to get a model of the right shape,
        # then overwrite its coefficients with the constrained weights
        all = allOrig
        data = pd.DataFrame(all).T
        data = data.dropna()
        y = data.iloc[:, -1]
        X = data.iloc[:, :-1]
        regr = linear_model.LinearRegression(fit_intercept=False)
        regr.fit(X, y)
        
        regr.coef_ = weights
    
    # fitted log returns of the target
    y_pred = regr.predict(X)

    
    print('Regressors:', [s.name for s in regressors])
    print('Coefficients:', regr.coef_)
    #print('Coefficients*:', list(regr.coef_) + [1-np.sum(regr.coef_)])
    #print("Mean squared error: %.2f" % mean_squared_error(diabetes_y_test, diabetes_y_pred))
    print('Variance score r^2: %.3f' % sk.metrics.r2_score(y, y_pred))

    # compound fitted log returns back into a growth series on the fitted dates
    y_pred = i_logret(pd.Series(y_pred, X.index))
    y_pred.name = target.name + " fit"
    #y_pred = "fit"
    y_pred = Wrapper(y_pred)
    show(target , y_pred)
    return y_pred
    
def dd(x):
    """Drawdown series in percent: distance of each value below its running peak (NaNs dropped).

    A Wrapper is unwrapped to its underlying series first.
    """
    series = x.s if isinstance(x, Wrapper) else x
    series = series.dropna()
    running_peak = np.maximum.accumulate(series)
    return (series / running_peak - 1) * 100
    
def percentile(s, p):
    """The p-th percentile (0..100) of `s`, via Series.quantile."""
    fraction = p / 100
    return s.quantile(fraction)



In [None]:
from IPython.core.display import Javascript
import time, os, stat

def save_notebook(verbose=True, sleep=True):
    """Ask the browser to click the classic-notebook toolbar "Save" button.

    verbose: print progress messages.
    sleep: block for 15 seconds so the browser has time to save before the caller continues.
    """
    from IPython.display import display
    # A Javascript object only runs in the browser once it is displayed.
    # Constructing it (as before) did nothing.
    display(Javascript('console.log(document.querySelector("div#save-notbook button").click())'))
    if verbose:
        print("save requested, sleeping to ensure execution ..")
    if sleep:
        time.sleep(15)
    if verbose:
        print("done")

# save live notebook at first run to make sure it's the latest modified file in the folder (for later publishing)
save_notebook(False, False)

def publish(name=None):
    """Save the live notebook and publish it with ./publish.sh, printing an nbviewer URL on success.

    name: published notebook name; a random uppercase hex UUID is used when not given.
    NOTE(review): relies on save(), which is defined in a later cell — that cell must have run first.
    """
    def file_age_in_seconds(pathname):
        # seconds since the file was last modified
        return time.time() - os.stat(pathname)[stat.ST_MTIME]

    # assume the most recently modified .ipynb in the working directory is this notebook
    filename = !ls -t *.ipynb | grep -v /$ | head -1
    filename = filename[0]

    # if that file was not modified in the last few seconds, trigger a save and look again
    age = int(file_age_in_seconds(filename))
    min_age = 5
    if age > min_age:
        print(filename + " file age is " + str(age) + " seconds old, auto saving current notebook ..")
        save_notebook()
        filename = !ls -t *.ipynb | grep -v /$ | head -1
        filename = filename[0]
    
    if not name:
        name = str(uuid.uuid4().hex.upper())
    save()
    print("Publishing " + filename + " ..")
    # publish.sh receives the target name as its only argument; exit code 0 means success
    res = subprocess.call(['bash', './publish.sh', name])
    if res == 0:
        print("published successfuly!")
        print("https://nbviewer.jupyter.org/github/ertpload/test/blob/master/__name__.ipynb".replace("__name__", name))
    else:
        print("Failed!")

In [None]:
from IPython.display import display,Javascript 
def save():
    """Ask the classic Jupyter frontend (via displayed JavaScript) to save a checkpoint."""
    checkpoint_js = Javascript('IPython.notebook.save_checkpoint();')
    display(checkpoint_js)

In [None]:
# make the plotly graphs look wider on mobile
from IPython.core.display import display, HTML
# CSS injected into the page: raises max-width of div.rendered_html elements to 10000px
# (module-level name `s`; displayed as HTML right below)
s = """
<style>
div.rendered_html {
    max-width: 10000px;
}
</style>
"""
display(HTML(s))

In [None]:
# interception to auto-fetch hardcoded symbols e.g:
# show(SPY)
# this should run last in the framework code, or it attempts to download unrelated symbols :)

from IPython.core.inputtransformer import *
import tokenize as _tokenize  # NAME / OP token-type constants; OP's numeric value differs across Python versions
intercept = True
if intercept and "my_transformer_tokens_instance" not in locals():
    # ALL-CAPS names that were assigned to, or already failed to fetch; never auto-fetched (again)
    attempted_implied_fetches = set()

    ip = get_ipython()

    @StatelessInputTransformer.wrap
    def my_transformer(line):
        """Example line transformer (not registered below): wraps lines starting with "x" in specialcommand(...)."""
        if line.startswith("x"):
            return "specialcommand(" + repr(line) + ")"
        return line

    @TokenInputTransformer.wrap
    def my_transformer_tokens(tokens):
        """Before a cell runs, auto-fetch bare ALL-CAPS names (e.g. SPY) into the user namespace via get().

        Tokens are returned unchanged. Names being assigned (`FOO = ...`), already defined in the user
        namespace, or previously failed are skipped.
        """
        for i, x in enumerate(tokens):
            if x.type == _tokenize.NAME and x.string.isupper() and x.string.isalpha():
                # Compare against tokenize.OP rather than a hard-coded 53: OP is 54 on
                # Python >= 3.8, where the old check never detected assignments.
                is_assignment = (i < len(tokens) - 1
                                 and tokens[i + 1].type == _tokenize.OP
                                 and tokens[i + 1].string == "=")
                if is_assignment:
                    attempted_implied_fetches.add(x.string)
                    continue
                if x.string in attempted_implied_fetches or x.string in ip.user_ns:
                    continue
                try:
                    ip.user_ns[x.string] = get(x.string)
                except Exception:  # best effort, but let KeyboardInterrupt/SystemExit through
                    print("Failed to fetch implied symbol: " + x.string)
                    attempted_implied_fetches.add(x.string)
        return tokens

    my_transformer_tokens_instance = my_transformer_tokens()

    ip.input_splitter.logical_line_transforms.append(my_transformer_tokens_instance)
    ip.input_transformer_manager.logical_line_transforms.append(my_transformer_tokens_instance)

In [None]:
def date(s):
    """Parse a YYYY-MM-DD string (or list-like of them) into pandas datetime(s)."""
    fmt = "%Y-%m-%d"
    return pd.to_datetime(s, format=fmt)


In [None]:
# ************* SYMBOLS ***************
# these are shorthand variables representing asset classes

# ==== SPECIAL ====
# https://www.federalreserve.gov/pubs/bulletin/2005/winter05_index.pdf
# Nominal Daily
usdMajor = 'FRED/DTWEXM@Q' # Trade Weighted U.S. Dollar Index: Major Currencies
usdBroad = 'FRED/DTWEXB@Q' # Trade Weighted U.S. Dollar Index: Broad
usdOther = 'FRED/DTWEXO@Q' # Trade Weighted U.S. Dollar Index: Other Important Trading Partners
# Nominal Monthly
usdMajorM = 'FRED/TWEXMMTH@Q'
usdBroadM = 'FRED/TWEXBMTH@Q'
usdOtherM = 'FRED/TWEXOMTH@Q'
# Real Monthly
usdMajorReal = 'FRED/TWEXMPA@Q' # Real Trade Weighted U.S. Dollar Index: Major Currencies
usdBroadReal = 'FRED/TWEXBPA@Q' # Real Trade Weighted U.S. Dollar Index: Broad
usdOtherReal = 'FRED/TWEXOPA@Q' # Real Trade Weighted U.S. Dollar Index: Other Important Trading Partners
usd = usdBroad

cpiUS = 'RATEINF/CPI_USA@Q'


#bitcoinAvg = price("BAVERAGE/USD@Q") # data 2010-2016
#bitcoinBitstamp = price("BCHARTS/BITSTAMPUSD@Q") # data 2011-now

# ==== STOCKS ====
# Global
g_ac = 'VTSMX:45|VGTSX:55' # VT # global all-cap
d_ac = 'URTH' # developed world (NOTE: rebound to 'VTMGX' in the ex-US section below)
# US
ac = 'VTSMX' # VTI # all-cap
lc = 'VFINX' # VOO, SPY # large-cap
mc = 'VIMSX' # VO # mid-cap
sc = 'NAESX' # VB # small-cap
mcc = 'BRSIX' # micro-cap
lcv = 'VIVAX' # IUSV # large-cap-value
mcv = 'VMVIX' # mid-cap-value
scv = 'VISVX' # VBR # small-cap-value
lcg = 'VIGRX' # large-cap-growth 
mcg = 'VMGIX' # mid-cap-growth
scg = 'VISGX' # VBK # small-cap-growth
# ex-US
i_ac = 'VGTSX' # VXUS # intl' all-cap
i_sc = 'VINEX' # VSS, SCZ # intl' small-cap
# NOTE(review): this rebinds d_ac (set to 'URTH' above), so the effective value of d_ac is 'VTMGX'
d_ac = 'VTMGX' # EFA, VEA # intl' developed
i_dev = d_ac # legacy
i_acv = 'DFIVX' # EFV # intl' all-cap-value
i_scv = 'DISVX' # DLS # intl' small-cap-value
em_ac = 'VEIEX' # VWO # emerging markets
em = em_ac # legacy
em_sc = 'EEMS' # emerging markets small cap
fr_ac = 'FRN' # FM # frontier markets

# ==== BONDS ====
# US GOVT
sgb = 'VFISX' # SHY, VGSH # short term govt bonds
tips = 'VIPSX' # TIP # inflation protected treasuries
lgb = 'VUSTX' # TLT, VGLT # long govt bonds
elgb = 'PEDIX' # EDV # extra-long (extended duration) govt bonds
gb = 'VFITX' # IEI # intermediate govt bonds
fgb = 'TFLO' # floating govt bonds
# US CORP 
cb = 'MFBFX' # LQD # corp bonds
scb = 'VCSH' # short-term-corp-bonds
lcb = 'VCLT' # long-term-corp-bonds
fcb = 'FLOT' # floating corp bonds
# US CORP+GOVT
gcb = 'VBMFX' # AGG, BND # govt/corp bonds
sgcb = 'VFSTX' # BSV # short-term-govt-corp-bonds
# International
i_tips = 'WIP' # # intl' local currency inflation protected bonds
i_gcbUsd = 'PFORX' # BNDX # ex-US govt/corp bonds (USD hedged)
i_gbLcl = 'BEGBX' # (getBwx()) BWX, IGOV # ex-US govt bonds (non hedged)
i_gb = i_gbLcl # legacy
i_cb = 'PIGLX' # PICB, ex-US corp bonds
i_cjb = 'IHY' # intl-corp-junk-bonds
g_gcbLcl = 'PIGLX' # Global bonds (non hedged)
g_gcbUsd = 'PGBIX' # Global bonds (USD hedged)
g_sgcb = 'LDUR' # Global short-term govt-corp bonds
g_usgcb = 'MINT' # Global ultra-short-term govt-corp bonds
em_gbUsd = 'FNMIX' # VWOB, EMB # emerging market govt bonds (USD hedged)
emb = em_gbUsd # legacy
em_gbLcl = 'PELBX' # LEMB, EBND, EMLC emerging-markets-govt-bonds (local currency) [LEMB Yahoo data is broken]
em_cjb = 'EMHY' # emerging-markets-corp-junk-bonds
cjb = 'VWEHX' # JNK, HYG # junk bonds
junk = cjb # legacy (was the string literal 'cjb', not the cjb symbol)
scjb = 'HYS' # short-term-corp-junk-bonds

# ==== CASH ====
rfr = 'SHV' # BIL # risk free return (1-3 month t-bills)
cash = rfr # SHV # risk free return (was the string literal 'rfr', not the rfr symbol)
cashLike = 'VFISX:30' # a poor approximation for rfr returns 

# ==== OTHER ====
fedRate = 'FRED/DFF@Q'
reit = 'DFREX' # VNQ # REIT
i_reit = 'RWX' # VNQI # ex-US REIT
g_reit = 'DFREX:50|RWX:50' # RWO # global REIT
gold = 'LBMA/GOLD@Q' # GLD # gold
silver = 'LBMA/SILVER@Q' # SLV # silver
palladium = 'LPPM/PALL@Q'
platinum = 'LPPM/PLAT@Q'
#metals = gold|silver|palladium|platinum # GLTR # precious metals (VGPMX is a stocks fund)
comm = 'DBC' # # commodities
oilWtiQ = 'FRED/DCOILWTICO@Q'
oilBrentQ = 'FRED/DCOILBRENTEU@Q'
oilBrentK = 'oil-prices@OKFN' # only loads first series which is brent
eden = 'EdenAlpha@MAN'

# ==== INDICES ====
spxPR = '^GSPC'
spxTR = '^SP500TR'
spx = spxPR


Another option for interception (an IPython `pre_execute` event hook):
```python
class VarWatcher(object):
    def __init__(self, ip):
        self.shell = ip
        self.last_x = None

    def pre_execute(self):
        if False:
            for k in dir(self.shell):
                print(k, ":", getattr(self.shell, k))
                print()
        #print("\n".join(dir(self.shell)))
        if "content" in self.shell.parent_header:
            code = self.shell.parent_header['content']['code']
            self.shell.user_ns[code] = 42
        #print(self.shell.user_ns.get('ASDF', None))

    def post_execute(self):
        pass
        #if self.shell.user_ns.get('x', None) != self.last_x:
        #    print("x changed!")

def load_ipython_extension(ip):
    vw = VarWatcher(ip)
    ip.events.register('pre_execute', vw.pre_execute)
    ip.events.register('post_execute', vw.post_execute)
    
ip = get_ipython()

load_ipython_extension(ip)   

```

In [None]:
def divs(symbolName, period=None, fill=False):
    """Dividend payouts (> 0) for a symbol.

    symbolName: a symbol, a (symbol, period) tuple, or a Wrapper/Series (its .name is used).
    period: if set, return the rolling sum over that many payouts.
    fill: reindex onto the union with the price series' dates, filling non-payout days with 0.
    """
    if period is None and isinstance(symbolName, tuple):
        symbolName, period = symbolName
    if isinstance(symbolName, Wrapper) or isinstance(symbolName, pd.Series):
        symbolName = symbolName.name
    payouts = get(symbolName, mode="divs")
    payouts = payouts[payouts.s > 0]
    if period:
        payouts = wrap(payouts.rolling(period).sum())
    if fill:
        price = get(symbolName)
        payouts = payouts.reindex(price.index.union(payouts.index), fill_value=0)
    return payouts

def getYield(symbolName, period=None, altPriceName=None):
    """Trailing dividend yield in percent: sum of the last `period` payouts / PR price * 100.

    symbolName: a symbol, a (symbol, period) tuple, or a Wrapper/Series (its .name is used).
    period: payouts per year. When None it is inferred as 12 // (median gap in months between the
        last few payouts). A single payout falls back to 1.
    altPriceName: fetch the price series from this symbol instead.
    Returns the (empty) dividend series when there are no payouts.
    """
    if isinstance(symbolName, tuple) and period is None:
        symbolName, period = symbolName
    if isinstance(symbolName, Wrapper) or isinstance(symbolName, pd.Series):
        symbolName = symbolName.name
    price = get(altPriceName or symbolName, mode="PR")
    divs = get(symbolName, mode="divs")
    divs = divs[divs.s > 0]
    if len(divs.s) == 0:
        return divs
    if period is None:
        # gap between consecutive payouts, in (rounded) months
        months_diff = (divs.s.index.to_series().diff().dt.days / 30).dropna().apply(lambda x: int(round(x)))
        months = months_diff[-5:].median()
        if pd.isnull(months):
            # a single payout has no gap to infer a frequency from (int(nan) used to raise)
            period = 1
        else:
            # gaps under ~2 weeks round to 0 months; clamp so 12 // 0 cannot raise
            period = int(12 // max(months, 1))
    if period:
        divs = wrap(divs.rolling(period).sum())
    return name(divs / price * 100, divs.name)

def get_curr_yield(s, period=None):
    """Most recent non-NaN value of getYield(s, period) (percent)."""
    yields = getYield(s, period=period).dropna()
    return yields[-1]

def get_curr_net_yield(s, period=None, tax=0.25):
    """Most recent yield (percent) after withholding `tax` (default 25%, matching getNtr)."""
    return getYield(s, period=period).dropna()[-1] * (1 - tax)

def get_TR_from_PR_and_divs(pr, divs):
    """Build a total-return series from a price series and a dividend series.

    Each dividend is reinvested at that day's price; the result is named "<pr.name> TR".
    NOTE(review): dates without a dividend yield NaN growth factors that are forward-filled after
    the cumulative product, so values before the first dividend stay NaN.
    """
    m = divs / pr + 1  # was `d / pr`, an undefined name (NameError on every call)
    mCP = m.cumprod().fillna(method="ffill")
    tr = pr * mCP
    return wrap(tr, pr.name + " TR")

def despike(s, std=8, window=30, shift=10):
    """Replace spikes in `s` with interpolated values; lists are despiked element-wise.

    A point is a spike when its log return deviates from the mean log return by more than `std`
    times the rolling std of log returns. That rolling std uses a `window`-point window lagged by
    `shift` points; where it is not yet available, the maximum log return is used instead.
    A Wrapper is unwrapped to its series; the result is wrapped under the original name.
    """
    if isinstance(s, list):
        # forward the tuning parameters (previously every element silently used the defaults)
        return [despike(x, std=std, window=window, shift=shift) for x in s]
    if "s" in dir(s):
        s = s.s
    new_s = s.copy()
    lr = logret(s).fillna(0)
    threshold = lr.shift(shift).rolling(window).std().fillna(lr.max()) * std
    new_s[(lr - lr.mean()).abs() > threshold] = np.nan
    return wrap(new_s.interpolate(), s.name)

## Generic Utils

In [1]:
def series_as_float(ser):
    """Safely convert a float/string/mixed series to floats.

    Everything is cast to str first, so that thousands separators (",") and "%" signs can be stripped
    without turning genuine numbers into NaN. Values that still don't parse become NaN.
    """
    as_text = ser.astype(str)
    stripped = as_text.str.replace(",", "").str.replace("%", "")
    return pd.to_numeric(stripped, errors="coerce")

def lmap(f, l):
    """Eager map: apply `f` to every item of `l` and return the results as a list."""
    return [f(item) for item in l]