In [36]:
import datetime as dt
import matplotlib.pyplot as plt
from matplotlib import style
import numpy as np
import os
import mplfinance as mpf
import pandas as pd
import pandas_datareader.data as web
import bs4 as bs
import pickle
import requests

# go to wiki to scrape the sp500 table into a pickle object 


In [13]:

def save_sp500_tickers():
    resp = requests.get('https://en.wikipedia.org/wiki/List_of_S%26P_500_companies')
    soup = bs.BeautifulSoup(resp.text)
    table = soup.find('table',{'class': 'wikitable sortable'})
    tickers = []
    for row in table.findAll('tr')[1:]: # skip the header(column names)
        ticker= row.findAll('td')[0].text # the first column
        tickers.append(ticker)
        
    with open("sp500tickers.pickle",'wb') as f: # write as byte
        pickle.dump(tickers,f) # save the data as a list
        
    return tickers

In [40]:
save_sp500_tickers()

['MMM\n',
 'ABT\n',
 'ABBV\n',
 'ABMD\n',
 'ACN\n',
 'ATVI\n',
 'ADBE\n',
 'AMD\n',
 'AAP\n',
 'AES\n',
 'AFL\n',
 'A\n',
 'APD\n',
 'AKAM\n',
 'ALK\n',
 'ALB\n',
 'ARE\n',
 'ALXN\n',
 'ALGN\n',
 'ALLE\n',
 'AGN\n',
 'ADS\n',
 'LNT\n',
 'ALL\n',
 'GOOGL\n',
 'GOOG\n',
 'MO\n',
 'AMZN\n',
 'AMCR\n',
 'AEE\n',
 'AAL\n',
 'AEP\n',
 'AXP\n',
 'AIG\n',
 'AMT\n',
 'AWK\n',
 'AMP\n',
 'ABC\n',
 'AME\n',
 'AMGN\n',
 'APH\n',
 'ADI\n',
 'ANSS\n',
 'ANTM\n',
 'AON\n',
 'AOS\n',
 'APA\n',
 'AIV\n',
 'AAPL\n',
 'AMAT\n',
 'APTV\n',
 'ADM\n',
 'ANET\n',
 'AJG\n',
 'AIZ\n',
 'T\n',
 'ATO\n',
 'ADSK\n',
 'ADP\n',
 'AZO\n',
 'AVB\n',
 'AVY\n',
 'BKR\n',
 'BLL\n',
 'BAC\n',
 'BK\n',
 'BAX\n',
 'BDX\n',
 'BRK.B\n',
 'BBY\n',
 'BIIB\n',
 'BLK\n',
 'BA\n',
 'BKNG\n',
 'BWA\n',
 'BXP\n',
 'BSX\n',
 'BMY\n',
 'AVGO\n',
 'BR\n',
 'BF.B\n',
 'CHRW\n',
 'COG\n',
 'CDNS\n',
 'CPB\n',
 'COF\n',
 'CPRI\n',
 'CAH\n',
 'KMX\n',
 'CCL\n',
 'CARR\n',
 'CAT\n',
 'CBOE\n',
 'CBRE\n',
 'CDW\n',
 'CE\n',
 'CNC\n',
 'CNP\

# get 500 ticker from yahoo

In [15]:
def get_data_from_yahoo(reload_sp500=False):
    if reload_sp500:
        tickers = save_sp500_tickers() #when the sp500 changes
    else:
        with open("sp500tickers.pickle",'rb') as f: # write as byte
            tickers = pickle.load(f) # save the data as a list
    
    if not os.path.exists('stock_dfs'):
        os.makedirs('stock_dfs')
        
    start = dt.datetime(2000,1,1)
    end = dt.datetime(2016,12,31)
    # get data from all the tickers
    for ticker in tickers:
        if not os.path.exists('stock_dfs/{}.csv'.format(ticker)):
            ticker = ticker.replace("\n", "")
            df = web.DataReader(ticker,'yahoo',start,end)
            df.to_csv('stock_dfs/{}.csv'.format(ticker))
            print(f"Finish with {ticker}")
        else:
            print('Already have {}'.format(ticker))

In [16]:
#get_data_from_yahoo()

# combine 500 in one dataframe

In [17]:
def compile_date():
    os.chdir('stock_dfs')
    csvs = os.listdir()
        
        
    main_df = pd.DataFrame()
    
    for count,csv in enumerate(csvs):

        df = pd.read_csv(csv)
        df.set_index('Date',inplace=True)
        
        df.rename(columns={'Adj Close': csv},inplace=True)
        df.drop(['Open','High','Low','Close','Volume'],1,inplace =True)
        
        if main_df.empty:
            main_df=df
        else:
            main_df=main_df.join(df,how='outer')
            
        if count % 10 ==0:
            print(count)
    os.chdir('..')
    main_df.to_csv('sp500tickers_joined_closes.csv')
    

In [22]:
#compile_date()

# correlations

##### def visualize_data():
    df = pd.read_csv('sp500tickers_joined_closes.csv')
    df.columns = [x.split('.')[0] for x in df.columns]
    df_corr = df.corr()
    data = df_corr.values
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    
    heatmap = ax.pcolor(data, cmap=plt.cm.RdYlGn)
    fig.colorbar(heatmap)
    ax.set_xticks(np.arange(data.shape[0])+ 0.5, minor = False)
    ax.set_yticks(np.arange(data.shape[1])+ 0.5, minor = False)
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    
    column_labels=df_corr.columns
    row_labels = df_corr.index
    
    ax.set_xticklabels(column_labels)
    ax.set_yticklabels(row_labels)
    plt.xticks(rotation=90)
    heatmap.set_clim(-1,1)
    plt.tight_layout()
    plt.show()

    
visualize_data()