# Fetch the data from yfinance

In [40]:
import pandas as pd
import yfinance as yf
import time
import os
import random
import tqdm
import numpy as np

def fetch_data(ticker, start='2015-01-01', end='2025-01-01', wait=True):
    """Return price history for ``ticker``, backed by a CSV cache on disk.

    If ``data/raw/autoadjusted/<ticker>.csv`` already exists it is loaded with
    ``pd.read_csv`` and returned unchanged. Otherwise the history is requested
    via ``yf.download(..., auto_adjust=True)``, written to that path and
    returned.

    The two paths return differently shaped frames: a cache hit yields the raw
    CSV contents (no parsed index), a download yields the frame produced by
    yfinance.

    Args:
        ticker: Symbol to fetch; also used as the cache file name.
        start: ``start`` argument passed to ``yf.download`` (unused on a
            cache hit).
        end: ``end`` argument passed to ``yf.download`` (unused on a cache
            hit).
        wait: If true, sleep a random 4-8 seconds after a successful download
            to space out consecutive requests.

    Returns:
        A DataFrame, or ``None`` when the download raised or produced no rows.
    """
    cache_dir = 'data/raw/autoadjusted'
    csv_path = f'{cache_dir}/{ticker}.csv'

    if os.path.exists(csv_path):
        print(f'{ticker} already exists')
        return pd.read_csv(csv_path)

    try:
        data = yf.download(ticker, start=start, end=end, auto_adjust=True)
        # NOTE(review): yfinance appears to report unknown/failed tickers by
        # returning an empty frame instead of raising. Never cache that:
        # an empty CSV would be treated as "already exists" on every later run.
        if data is None or data.empty:
            print(f'Error fetching {ticker}: no data returned')
            return None
        # Create the cache folder on first use so to_csv cannot fail on a
        # missing directory.
        os.makedirs(cache_dir, exist_ok=True)
        data.to_csv(csv_path)
        print(f'{ticker} fetched and saved')
        if wait:
            time.sleep(random.randint(4, 8))
        return data
    except Exception as e:
        print(f'Error fetching {ticker}: {e}')
        return None
        
def get_sp500_tickers(yahoo_format=False):
    """
    Get list of S&P 500 tickers from Wikipedia

    Reads the first HTML table on the "List of S&P 500 companies" page and
    returns the values of its ``Symbol`` column.

    Args:
        yahoo_format: If true, replace ``.`` with ``-`` in every symbol
            (e.g. ``BRK.B`` -> ``BRK-B``). NOTE(review): Yahoo Finance is
            understood to use the dashed form for share classes - confirm
            before enabling. Defaults to False, which returns the symbols
            exactly as listed.

    Returns:
        A list of ticker symbols, or ``None`` if the page could not be read
        or parsed (the error is printed).
    """
    try:
        # Get S&P 500 table from Wikipedia
        url = 'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies'
        table = pd.read_html(url)[0]

        # Get tickers from the 'Symbol' column
        tickers = table['Symbol'].tolist()

        if yahoo_format:
            tickers = [str(symbol).replace('.', '-') for symbol in tickers]

        print(f"Retrieved {len(tickers)} S&P 500 tickers")
        return tickers

    except Exception as e:
        print(f"Error getting S&P 500 tickers: {e}")
        return None
    
def format_data(data):
    """Convert a raw cached CSV frame into a frame indexed by date.

    The first two rows are dropped (elsewhere in this file they are described
    as the "Ticker" and "Date" rows of the saved CSV), the ``Price`` column is
    parsed as dates and becomes the ``Date`` index. Other columns are left
    untouched.

    Args:
        data: DataFrame as returned by ``pd.read_csv`` on a cached file.

    Returns:
        A new DataFrame indexed by ``Date``, or ``None`` if formatting failed
        (the error is printed).
    """
    # A DataFrame normally has no ``name`` attribute, so reading ``data.name``
    # inside the handler would raise and mask the real error. Resolve a label
    # up front, with a fallback.
    label = getattr(data, 'name', 'unknown')
    try:
        data = data.iloc[2:].copy()
        data['Date'] = pd.to_datetime(data['Price'])
        data.drop('Price', axis=1, inplace=True)
        data.set_index('Date', inplace=True)
    except Exception as e:
        print(f"Error formatting data for {label}: {e}")
        return None
    return data
    
    
    
def import_all_data(folder_with_csv):
    """Load every ``.csv`` file in a folder into a dict of DataFrames.

    Args:
        folder_with_csv: Directory to scan (not recursive).

    Returns:
        Dict mapping each file name without its ``.csv`` extension to the
        DataFrame returned by ``pd.read_csv``.
    """
    all_data = {}
    for file in sorted(os.listdir(folder_with_csv)):
        path = os.path.join(folder_with_csv, file)
        # splitext strips only the final extension, so dotted tickers such as
        # 'BRK.B.csv' keep their full name instead of being cut at the first dot.
        stem, ext = os.path.splitext(file)
        # Skip sub-directories and non-CSV files rather than feeding them
        # to read_csv.
        if ext.lower() != '.csv' or not os.path.isfile(path):
            continue
        all_data[stem] = pd.read_csv(path)
    return all_data


def verify_data(data_df, first='2010-01-04', last='2024-12-31'):
    """Check that a frame's index starts at ``first`` and ends at ``last``.

    Only the first and last index entries are compared; nothing in between
    is inspected.

    Args:
        data_df: DataFrame whose index holds the dates (compared against
            ``pd.Timestamp`` values, so it should be a datetime index).
        first: Expected first date (default 2010-01-04).
        last: Expected last date (default 2024-12-31).

    Returns:
        True if both ends match; False otherwise, including for an empty
        frame. Mismatches are printed.
    """
    expected_first = pd.Timestamp(first)
    expected_last = pd.Timestamp(last)

    # An empty frame has no first/last entry; report it instead of raising
    # IndexError on index[0].
    if len(data_df.index) == 0:
        print("Error")
        print(f"Expected range: {expected_first} to {expected_last}")
        print("Actual range: empty dataframe")
        return False

    first_date = data_df.index[0]
    last_date = data_df.index[-1]

    if first_date != expected_first or last_date != expected_last:
        print("Error")
        print(f"Expected range: {expected_first} to {expected_last}")
        print(f"Actual range: {first_date} to {last_date}")
        return False

    return True

In [12]:
sp500_tickers = get_sp500_tickers()
# get_sp500_tickers() returns None when the lookup fails; stop here with a
# clear message instead of a TypeError from slicing None.
if sp500_tickers is None:
    raise RuntimeError("Could not retrieve the S&P 500 ticker list")
print(sp500_tickers[:10])

Retrieved 503 S&P 500 tickers
['MMM', 'AOS', 'ABT', 'ABBV', 'ACN', 'ADBE', 'AMD', 'AES', 'AFL', 'A']


In [13]:
# Count the CSV files located directly in data/raw.
csv_count = sum(1 for name in os.listdir('data/raw') if name.endswith('.csv'))
print(f"Number of CSV files in data/raw: {csv_count}")

Number of CSV files in data/raw: 503


In [None]:
# Date window passed to fetch_data for every ticker that is not cached yet.
start, end = '2010-01-01', '2025-01-01'

for ticker in tqdm.tqdm(sp500_tickers):
    csv_path = os.path.join('data/raw/autoadjusted/', f'{ticker}.csv')
    if not os.path.exists(csv_path):
        # Not on disk yet: fetch_data downloads it, saves the CSV and, with
        # wait=True, pauses before returning.
        data = fetch_data(ticker, start, end, wait=True)
        continue
    print(f"{ticker} already exists")

100%|██████████| 503/503 [00:00<00:00, 6668.17it/s]

MMM already exists
AOS already exists
ABT already exists
ABBV already exists
ACN already exists
ADBE already exists
AMD already exists
AES already exists
AFL already exists
A already exists
APD already exists
ABNB already exists
AKAM already exists
ALB already exists
ARE already exists
ALGN already exists
ALLE already exists
LNT already exists
ALL already exists
GOOGL already exists
GOOG already exists
MO already exists
AMZN already exists
AMCR already exists
AEE already exists
AEP already exists
AXP already exists
AIG already exists
AMT already exists
AWK already exists
AMP already exists
AME already exists
AMGN already exists
APH already exists
ADI already exists
AON already exists
APA already exists
APO already exists
AAPL already exists
AMAT already exists
APTV already exists
ACGL already exists
ADM already exists
ANET already exists
AJG already exists
AIZ already exists
T already exists
ATO already exists
ADSK already exists
ADP already exists
AZO already exists
AVB already exists




In [18]:
# Compare how many CSV files each of the two folders holds.
raw_count = sum(1 for name in os.listdir('data/raw') if name.endswith('.csv'))
autoadj_count = sum(
    1 for name in os.listdir('data/raw/autoadjusted') if name.endswith('.csv')
)

print(f"Number of CSV files in data/raw: {raw_count}")
print(f"Number of CSV files in data/raw/autoadjusted: {autoadj_count}")


Number of CSV files in data/raw: 503
Number of CSV files in data/raw/autoadjusted: 503


In [24]:
def verify_date_range(df, start="2010-01-04", end="2024-12-31"):
    """Return True when ``df`` is non-empty and its last index entry is ``end``.

    ``start`` is accepted but not checked: the stocks begin on different
    dates, so only the final date is compared (formatted as YYYY-MM-DD).
    """
    if df.index.empty:
        return False
    final_day = df.index[-1].strftime('%Y-%m-%d')
    return final_day == end

def verify_nan_values(df):
    """Return True when ``df`` contains no missing values, False otherwise."""
    has_missing = df.isna().to_numpy().any()
    return not has_missing

good_data = {}
bad_data = {}

# Load every cached CSV, index it by date and sort it into good_data or
# bad_data depending on the end-date and NaN checks.
for file in filter(lambda name: name.endswith('.csv'), os.listdir('data/raw/autoadjusted')):
    ticker = file[:-len('.csv')]
    csv_path = os.path.join('data/raw/autoadjusted', file)

    try:
        # Rows 1 and 2 of the file are the Ticker and Date rows; the dates
        # themselves sit in the column labelled 'Price'.
        df = pd.read_csv(csv_path, skiprows=[1, 2]).rename(columns={'Price': 'Date'})
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.set_index('Date')

        if verify_date_range(df) and verify_nan_values(df):
            good_data[ticker] = df
            # Destination for a cleaned copy; the actual write
            # (df.to_csv(csv_path)) is currently disabled.
            csv_path = os.path.join('data/raw', file)
        else:
            print(f"Verification failed for {ticker}")
            if len(df.index) > 0:
                print(f"Date range: {df.index[0]} to {df.index[-1]}")
            else:
                print("Empty dataframe")
            print(f"Nan values: {df.isna().any().any()}")
            print("--------------------------------")
            bad_data[ticker] = df

    except Exception as e:
        print(f"Error processing {ticker}: {str(e)}")
        print("--------------------------------")

In [27]:
# Load one cleaned file and index it by parsed dates for a quick look.
ticker = 'A'
df = pd.read_csv(f'data/raw/{ticker}.csv').set_index('Date')
df.index = pd.to_datetime(df.index)
df

Unnamed: 0_level_0,Close,High,Low,Open,Volume
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2010-01-04,19.931620,20.141761,19.823364,19.988930,3815561
2010-01-05,19.715111,19.880677,19.587751,19.874309,4186031
2010-01-06,19.645058,19.740576,19.587745,19.645058,3243779
2010-01-07,19.619595,19.625962,19.422188,19.600489,3095172
2010-01-08,19.613226,19.645067,19.358510,19.511340,3733918
...,...,...,...,...,...
2024-12-24,135.276016,135.276016,133.337810,133.785098,370200
2024-12-26,135.007660,135.156746,134.152866,134.192617,556600
2024-12-27,134.719406,135.395297,133.775161,133.804978,631800
2024-12-30,133.606186,134.769101,132.433337,133.586303,993600


# Preprocess data

In [33]:
# Walk the CSV files in data/raw and split them by whether their date range
# passes verify_data.
good_data = {}
error_data = {}

for file in filter(lambda name: name.endswith('.csv'), os.listdir('data/raw')):
    ticker = file[:-len('.csv')]
    csv_path = os.path.join('data/raw', file)

    try:
        data = pd.read_csv(csv_path).set_index('Date')
        data.index = pd.to_datetime(data.index)

        if not verify_data(data):
            error_data[ticker] = data
            print(f"Data verification failed for {ticker}")
            continue
        good_data[ticker] = data

    except Exception as e:
        # Unreadable or malformed file: record the ticker without a frame.
        print(f"Error processing {ticker}: {str(e)}")
        error_data[ticker] = None

Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2013-01-02 00:00:00 to 2024-12-31 00:00:00
Data verification failed for ABBV
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2020-12-10 00:00:00 to 2024-12-31 00:00:00
Data verification failed for ABNB
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2013-11-18 00:00:00 to 2024-12-31 00:00:00
Data verification failed for ALLE
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2012-05-15 00:00:00 to 2024-12-31 00:00:00
Data verification failed for AMCR
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2014-06-06 00:00:00 to 2024-12-31 00:00:00
Data verification failed for ANET
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00
Actual range: 2011-03-30 00:00:00 to 2024-12-31 00:00:00
Data verification failed for APO
Error
Expected range: 2010-01-04 00:00:00 to 2024-12-31 00:00:00


In [36]:
# Number of tickers that passed, then number that failed, one per line.
print(len(good_data), len(error_data), sep='\n')

423
78


In [39]:
# One column of closing prices per verified ticker.
close_prices = dict((name, frame['Close']) for name, frame in good_data.items())
dataset = pd.DataFrame(close_prices)

# Write the price table only when no file exists yet.
if not os.path.exists("data/processed/dataset_prices.csv"):
    dataset.to_csv("data/processed/dataset_prices.csv")
    print("Dataset saved!")

dataset

Unnamed: 0_level_0,A,AAPL,ABT,ACGL,ACN,ADBE,ADI,ADM,ADP,ADSK,...,WSM,WST,WTW,WY,WYNN,XEL,XOM,YUM,ZBH,ZBRA
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2010-01-04,19.931620,6.424604,18.496670,7.601905,31.492178,37.090000,22.062073,20.791342,26.033312,25.670000,...,7.264576,17.557016,52.100517,9.568550,40.966465,12.367372,38.568714,18.697779,51.990017,28.670000
2010-01-05,19.715111,6.435713,18.347229,7.576549,31.686810,37.700001,22.027241,20.903652,25.893505,25.280001,...,7.466855,17.334608,51.983734,9.771677,43.458023,12.220693,38.719311,18.633829,53.635796,28.620001
2010-01-06,19.645058,6.333344,18.449114,7.543795,32.023655,37.619999,21.985449,20.850796,25.832731,25.340000,...,7.731910,17.165573,52.820610,9.663627,42.887981,12.244164,39.053955,18.500622,53.618462,28.400000
2010-01-07,19.619595,6.321636,18.601961,7.499420,31.993710,36.889999,21.811291,20.632776,25.820568,25.480000,...,8.101590,17.218956,52.664909,9.620409,43.803902,12.191360,38.931263,18.495293,54.848499,27.690001
2010-01-08,19.613226,6.363666,18.697060,7.484628,31.866465,36.689999,21.936684,20.375113,25.784111,26.260000,...,7.976042,17.214504,52.606548,9.531811,43.490051,12.197227,38.775097,18.500622,53.696442,27.600000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2024-12-24,135.276016,257.286682,113.126747,92.669998,356.539642,447.940002,216.161041,49.560802,293.422546,301.230011,...,185.572479,331.955566,313.332672,27.837299,88.309021,66.534348,104.494308,133.709183,106.553841,395.440002
2024-12-26,135.007660,258.103729,113.629486,92.930000,355.356537,450.160004,216.131317,49.541222,294.184692,300.279999,...,185.473633,332.613922,315.156952,27.748648,88.836212,66.505081,104.582695,134.699615,106.504189,396.850006
2024-12-27,134.719406,254.685867,113.353470,92.339996,351.166321,446.480011,215.070786,49.511856,293.145447,297.589996,...,183.911789,332.404419,313.610291,27.571339,88.209557,66.466064,104.572876,133.936981,106.126839,389.070007
2024-12-30,133.606186,251.307877,111.194641,91.889999,347.528290,445.799988,210.679962,49.012627,289.968323,297.529999,...,184.366486,328.015228,310.566467,27.472834,85.374672,65.929565,103.865776,132.243332,104.902634,383.850006


In [42]:
# Calculate returns and log returns
dataset = pd.read_csv('data/processed/dataset_prices.csv')
dataset = dataset.set_index('Date')
dataset = dataset.astype(float)

# Simple returns. fill_method=None disables pct_change's implicit
# forward-fill of missing prices (that default is deprecated in recent
# pandas), so both return series handle gaps the same way: a missing price
# yields NaN instead of a fabricated 0% return.
returns = dataset.pct_change(fill_method=None)
# Log returns: log(P_t / P_{t-1}); the first row is NaN by construction.
log_returns = np.log(dataset / dataset.shift(1))

# Save both returns datasets if they don't exist
if not os.path.exists("data/processed/dataset_returns.csv"):
    returns.to_csv("data/processed/dataset_returns.csv")
    print("Returns dataset saved!")

if not os.path.exists("data/processed/dataset_log_returns.csv"):
    log_returns.to_csv("data/processed/dataset_log_returns.csv")
    print("Log returns dataset saved!")

returns

Returns dataset saved!
Log returns dataset saved!


Unnamed: 0_level_0,A,AAPL,ABT,ACGL,ACN,ADBE,ADI,ADM,ADP,ADSK,...,WSM,WST,WTW,WY,WYNN,XEL,XOM,YUM,ZBH,ZBRA
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2010-01-04,,,,,,,,,,,...,,,,,,,,,,
2010-01-05,-0.010863,0.001729,-0.008079,-0.003336,0.006180,0.016446,-0.001579,0.005402,-0.005370,-0.015193,...,0.027845,-0.012668,-0.002241,0.021229,0.060819,-0.011860,0.003905,-0.003420,0.031656,-0.001744
2010-01-06,-0.003553,-0.015906,0.005553,-0.004323,0.010630,-0.002122,-0.001897,-0.002529,-0.002347,0.002373,...,0.035498,-0.009751,0.016099,-0.011058,-0.013117,0.001921,0.008643,-0.007149,-0.000323,-0.007687
2010-01-07,-0.001296,-0.001849,0.008285,-0.005882,-0.000935,-0.019405,-0.007922,-0.010456,-0.000471,0.005525,...,0.047812,0.003110,-0.002948,-0.004472,0.021356,-0.004313,-0.003142,-0.000288,0.022941,-0.025000
2010-01-08,-0.000325,0.006648,0.005112,-0.001972,-0.003977,-0.005422,0.005749,-0.012488,-0.001412,0.030612,...,-0.015497,-0.000259,-0.001108,-0.009209,-0.007165,0.000481,-0.004011,0.000288,-0.021004,-0.003250
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2024-12-24,0.011144,0.011478,0.003937,0.006298,0.007972,0.002686,0.014891,0.004364,0.007853,0.012572,...,0.006487,0.009710,0.005248,0.006410,0.007604,0.007236,0.000941,0.008516,0.004494,0.012262
2024-12-26,-0.001984,0.003176,0.004444,0.002806,-0.003318,0.004956,-0.000138,-0.000395,0.002597,-0.003154,...,-0.000533,0.001983,0.005822,-0.003185,0.005970,-0.000440,0.000846,0.007407,-0.000466,0.003566
2024-12-27,-0.002135,-0.013242,-0.002429,-0.006349,-0.011792,-0.008175,-0.004907,-0.000593,-0.003533,-0.008958,...,-0.008421,-0.000630,-0.004908,-0.006390,-0.007054,-0.000587,-0.000094,-0.005662,-0.003543,-0.019604
2024-12-30,-0.008263,-0.013263,-0.019045,-0.004873,-0.010360,-0.001523,-0.020416,-0.010083,-0.010838,-0.000202,...,0.002472,-0.013204,-0.009706,-0.003573,-0.032138,-0.008072,-0.006762,-0.012645,-0.011535,-0.013417


# Process sector datasets

In [47]:
# Fetch S&P 500 sectors from Wikipedia
import pandas as pd

# Load the S&P 500 constituents table from Wikipedia
sp500_wiki = pd.read_html('https://en.wikipedia.org/wiki/List_of_S%26P_500_companies')[0]

sectors_df = sp500_wiki[['Symbol', 'GICS Sector']].set_index('Symbol')
our_tickers = list(dataset.columns)

# Look up the GICS sector of every ticker in our dataset; tickers missing
# from the Wikipedia table are reported and labelled 'unknown'.
sectors_mapping = {}
for ticker in our_tickers:
    if ticker not in sectors_df.index:
        print(f"Ticker {ticker} not found in S&P 500")
        sectors_mapping[ticker] = 'unknown'
        continue
    sectors_mapping[ticker] = sectors_df.loc[ticker, 'GICS Sector']

# Replace the lookup table with the ticker -> Sector frame and save it
# unless a file is already present.
sectors_df = pd.DataFrame.from_dict(sectors_mapping, orient='index', columns=['Sector'])
if not os.path.exists("data/raw/tickers_sectors.csv"):
    sectors_df.to_csv("data/raw/tickers_sectors.csv")
    print("Sectors mapping saved!")


sectors_df['Sector'].unique()

Sectors mapping saved!


array(['Health Care', 'Information Technology', 'Financials',
       'Consumer Staples', 'Industrials', 'Utilities', 'Materials',
       'Real Estate', 'Consumer Discretionary', 'Energy',
       'Communication Services'], dtype=object)

In [48]:
sectors = pd.read_csv("data/raw/tickers_sectors.csv", index_col=0)
sector_list = sectors.Sector.unique().tolist()

# Write one returns / log-returns pair per sector
for sector in sectors.Sector.unique():
    # Tickers (index labels) that belong to the current sector
    sector_tickers = sectors.index[sectors.Sector == sector].tolist()

    sector_returns = returns.loc[:, sector_tickers]
    sector_log_returns = log_returns.loc[:, sector_tickers]

    # Folder name: lower-cased sector with spaces turned into underscores
    sector_dir = f"data/processed/sectors/{sector.lower().replace(' ', '_')}"
    if not os.path.exists(sector_dir):
        os.makedirs(sector_dir)

    # Save datasets
    sector_returns.to_csv(f"{sector_dir}/returns.csv")
    sector_log_returns.to_csv(f"{sector_dir}/log_returns.csv")

print("Sectoral datasets created!")

Sectoral datasets created!
