# In [None]:
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import yfinance as yf
import logging
import os
import json
import time
from datetime import datetime, timedelta

# Select 'notebook' as Plotly's default renderer for the figures shown below.
pio.renderers.default = 'notebook'

class StockResearch:
    """Download, cache and chart fundamentals and prices for a ticker list.

    Tickers are read from a text file (one symbol per line) and optional
    settings from a JSON file with the keys ``interval`` and ``freq``.
    Data is fetched through yfinance and cached as CSV files in
    ``self.folder``; the analysis methods read those CSV files back.
    Calling the instance runs every analysis in sequence.
    """

    COLORS = {
        'primary': '#2563eb',      # Blue
        'secondary': '#10b981',    # Green
        'accent': '#f59e0b',       # Orange
        'danger': '#ef4444',       # Red
        'bg_dark': '#1e293b',      # Dark background
        'bg_light': '#334155',     # Light background
        'text': '#f1f5f9',         # Light text
        'grid': 'rgba(148, 163, 184, 0.1)'
    }

    # Number of price bars per year for each interval that can be annualized.
    _PERIODS_PER_YEAR = {'1d': 252, '1wk': 52, '1mo': 12}

    # Look-back windows, in price bars, used by momentum_score().
    _MOMENTUM_LOOKBACKS = {'shortterm': 4, 'midterm': 12, 'longterm': 24}

    def __init__(self, file=None, configs=None):
        """Load tickers and settings, then download any missing data.

        Args:
            file: Path to a text file with one ticker per line, or None.
            configs: Path to a JSON settings file, or None.
        """
        self.logger = logging.getLogger(__name__)
        self.folder = 'data'
        self.file = file
        self.configs = configs
        # Assign defaults first so these attributes exist even when the
        # ticker or settings file is missing or unreadable.
        self.tickers = []
        self.interval = '1wk'
        self.freq = 'quarterly'
        self.stocks_existing = self._load_stocks()
        self._load_configs()
        if self.stocks_existing:
            self._get_data()

    def _load_stocks(self):
        """Read tickers from ``self.file`` into ``self.tickers``.

        Blank lines are ignored and repeated symbols are kept only once.

        Returns:
            True when at least one ticker was loaded, otherwise False.
        """
        self.tickers = []
        if self.file is None:
            return False
        try:
            with open(self.file, 'r') as f:
                symbols = [line.strip() for line in f]
        except FileNotFoundError:
            self.logger.error(f'{self.file} not found')
            return False
        # dict.fromkeys drops duplicates while preserving file order.
        self.tickers = list(dict.fromkeys(s for s in symbols if s))
        return bool(self.tickers)

    def _load_configs(self):
        """Read ``interval`` and ``freq`` from the JSON file ``self.configs``.

        Returns:
            True when the file was read and applied, otherwise False (no
            path given, file missing, or content is not a JSON object).
        """
        if self.configs is None:
            return False
        try:
            with open(self.configs, 'r') as f:
                content = json.load(f)
        except FileNotFoundError:
            self.logger.warning('Configs file not found')
            return False
        except json.JSONDecodeError as e:
            self.logger.warning(f'Configs file is not valid JSON: {e}')
            return False
        if not isinstance(content, dict):
            self.logger.warning('Configs file must contain a JSON object')
            return False
        self.interval = content.get('interval', '1wk')
        self.freq = content.get('freq', 'quarterly')
        return True

    def _fundamentals_path(self, ticker):
        """Return the CSV cache path of the fundamentals for ``ticker``."""
        return f'{self.folder}/{ticker}_fundamental_data_{self.freq}.csv'

    def _price_path(self, ticker):
        """Return the CSV cache path of the price history for ``ticker``."""
        return f'{self.folder}/{ticker}_price_data_{self.interval}.csv'

    def _read_close(self, ticker):
        """Return the cached 'Close' column of ``ticker`` with a parsed date index."""
        prices = pd.read_csv(self._price_path(ticker), index_col=0, parse_dates=True)
        return prices['Close']

    def _get_data(self):
        """Download fundamentals and prices for tickers without a cached CSV."""
        os.makedirs(self.folder, exist_ok=True)
        for ticker in self.tickers:
            if not os.path.exists(self._fundamentals_path(ticker)):
                self._get_fundamentals(ticker)
                time.sleep(0.5)  # pause between consecutive downloads
            if not os.path.exists(self._price_path(ticker)):
                self._get_ohlc(ticker)
                time.sleep(0.5)

    def _get_fundamentals(self, ticker):
        """Fetch income-statement and cash-flow rows and cache them as CSV.

        Failures are logged as warnings; no file is written in that case.
        """
        try:
            obj = yf.Ticker(ticker=ticker)
            income = obj.get_income_stmt(freq=self.freq)
            cashflow = obj.get_cashflow(freq=self.freq)
            df = pd.concat([income, cashflow], join='inner', axis=0)
            if df.empty:
                self.logger.warning(f'Could not fetch fundamentals for {ticker}: no data returned')
                return
            # Reverse the column order; the readers below treat the LAST
            # column as the most recent period (assumes yfinance returns
            # newest-first - TODO confirm).
            df = df[df.columns[::-1]]
            if self.freq == 'quarterly':
                df.columns = df.columns.to_period('Q').astype('str')
            else:
                df.columns = df.columns.year.astype('str')
            df.to_csv(self._fundamentals_path(ticker))
        except Exception as e:
            self.logger.warning(f'Could not fetch fundamentals for {ticker}: {e}')

    def _get_ohlc(self, ticker):
        """Fetch the full adjusted price history and cache it as CSV.

        Columns are selected by name instead of being relabeled by position,
        so the file is correct regardless of the column order or the
        (field, ticker) MultiIndex layout that yfinance returns.
        """
        try:
            df = yf.download(ticker, interval=self.interval, period='max', auto_adjust=True, progress=False)
            if df is None or df.empty:
                self.logger.warning(f'Could not fetch OHLC for {ticker}: no data returned')
                return
            if isinstance(df.columns, pd.MultiIndex):
                # Keep only the field level (Open/High/...), drop the ticker level.
                df.columns = df.columns.get_level_values(0)
            df = df[['Open', 'High', 'Low', 'Close', 'Volume']]
            df.to_csv(self._price_path(ticker))
        except Exception as e:
            self.logger.warning(f'Could not fetch OHLC for {ticker}: {e}')

    def _get_base_layout(self, title, height=600):
        """Return a fresh dict of shared Plotly layout settings.

        Args:
            title: Figure title; rendered in bold and centered.
            height: Figure height passed through to the layout.
        """
        return dict(
            template='plotly_dark',
            paper_bgcolor=self.COLORS['bg_dark'],
            plot_bgcolor=self.COLORS['bg_light'],
            font=dict(family='Arial, sans-serif', size=12, color=self.COLORS['text']),
            title=dict(text=f'<b>{title}</b>', font=dict(size=18), x=0.5, xanchor='center'),
            margin=dict(l=60, r=60, t=80, b=60),
            height=height,
            xaxis=dict(gridcolor=self.COLORS['grid'], showgrid=True),
            yaxis=dict(gridcolor=self.COLORS['grid'], showgrid=True),
            hovermode='x unified',
            hoverlabel=dict(bgcolor=self.COLORS['bg_dark'], font_size=11)
        )

    def _trailing_eps(self, eps):
        """Return the EPS figure used as the P/E denominator.

        Args:
            eps: Non-null DilutedEPS values ordered oldest to newest.

        For yearly data this is the latest value. Otherwise it is the sum of
        the last four values; with fewer than four, the missing ones are
        filled with the mean of the available values.
        """
        if self.freq == 'yearly':
            return eps.iloc[-1]
        if len(eps) < 4:
            return eps.sum() + (4 - len(eps)) * eps.mean()
        return eps.iloc[-4:].sum()

    def price_to_earnings(self):
        """Print a trailing/forward P/E table and return the tickers in table order.

        Returns:
            Tickers sorted by 'Discount(%)' (descending, missing values
            last), or an empty list when no tickers are loaded.
        """
        if not self.stocks_existing:
            return []

        columns = ['Price', 'PE_TTM', 'Forward_PE', 'Discount(%)']
        rows = {}
        for ticker in self.tickers:
            # Every row starts with all columns so later lookups and the
            # final sort never hit a missing column.
            row = dict.fromkeys(columns, np.nan)
            try:
                fundamentals = pd.read_csv(self._fundamentals_path(ticker), index_col=0)
                diluted_eps = self._trailing_eps(fundamentals.loc['DilutedEPS'].dropna())
                prices = pd.read_csv(self._price_path(ticker), index_col=0)
                price = prices['Close'].iloc[-1]
                # A P/E ratio is only reported for positive earnings.
                pe = price / diluted_eps if diluted_eps > 0 else np.nan
                row['Price'] = round(price, 2)
                row['PE_TTM'] = round(pe, 2)
            except Exception as e:
                self.logger.warning(f'Could not calculate PE for {ticker}: {e}')

            try:
                time.sleep(0.5)
                fpe = yf.Ticker(ticker=ticker).info.get('forwardPE')
                if fpe:
                    row['Forward_PE'] = round(fpe, 2)
            except Exception as e:
                self.logger.warning(f'Could not fetch forward PE for {ticker}: {e}')

            if not pd.isna(row['PE_TTM']) and not pd.isna(row['Forward_PE']):
                discount = ((row['PE_TTM'] / row['Forward_PE']) - 1) * 100
                row['Discount(%)'] = round(discount, 2)
            rows[ticker] = row

        pe_df = pd.DataFrame.from_dict(rows, orient='index', columns=columns)
        if pe_df.empty:
            return []
        pe_df.sort_values(by='Discount(%)', ascending=False, inplace=True)
        print('\n=== Price/Earnings Analysis ===')
        print(pe_df.to_string())
        return list(pe_df.index)

    def return_correlation(self, stocks=None):
        """Show a heatmap of the correlation between per-period returns.

        Args:
            stocks: Tickers to include; falls back to all loaded tickers
                when None or empty.
        """
        if not self.stocks_existing:
            return

        tickers = stocks if stocks else self.tickers
        to_merge = []

        for ticker in tickers:
            try:
                to_merge.append(self._read_close(ticker).rename(ticker))
            except (OSError, KeyError, ValueError) as e:
                self.logger.warning(f'Could not load prices for {ticker}: {e}')

        if not to_merge:
            # pd.concat raises on an empty list, so stop here instead.
            self.logger.warning('No price data available for correlation')
            return

        merged_df = pd.concat(to_merge, join='inner', axis=1)
        corr = merged_df.pct_change().corr()

        fig = go.Figure(data=go.Heatmap(
            z=corr.values,
            x=corr.columns,
            y=corr.columns,
            colorscale='RdBu',
            zmid=0,
            text=corr.values,
            texttemplate='%{text:.2f}',
            textfont=dict(size=10),
            colorbar=dict(title='Correlation', thickness=15, len=0.7)
        ))

        layout = self._get_base_layout('Return Correlation Matrix', height=650)
        layout['xaxis'].update(tickangle=-45)
        fig.update_layout(**layout)
        fig.show()

    def growth(self, stocks=None):
        """Show net/FCF margin bars and revenue growth bars for each ticker.

        Args:
            stocks: Tickers to include; falls back to all loaded tickers
                when None or empty.
        """
        if not self.stocks_existing:
            return

        tickers = stocks if stocks else self.tickers

        for ticker in tickers:
            try:
                df = pd.read_csv(self._fundamentals_path(ticker), index_col=0)
                # Keep only periods where all three line items are present.
                required = df.loc[['NetIncome', 'FreeCashFlow', 'TotalRevenue']].dropna(axis=1)

                netmargin = (required.loc['NetIncome'] / required.loc['TotalRevenue'] * 100)
                fcfmargin = (required.loc['FreeCashFlow'] / required.loc['TotalRevenue'] * 100)
                # The first period has no predecessor; it is plotted as 0.
                revgrowth = (required.loc['TotalRevenue'].pct_change().fillna(0) * 100)

                fig = make_subplots(
                    rows=2, cols=1,
                    subplot_titles=('Profit Margins', 'Revenue Growth'),
                    row_heights=[0.6, 0.4],
                    vertical_spacing=0.12
                )

                # Margins
                fig.add_trace(go.Bar(
                    x=netmargin.index, y=netmargin.values,
                    name='Net Margin',
                    marker_color=self.COLORS['primary']
                ), row=1, col=1)

                fig.add_trace(go.Bar(
                    x=fcfmargin.index, y=fcfmargin.values,
                    name='FCF Margin',
                    marker_color=self.COLORS['secondary']
                ), row=1, col=1)

                # Growth
                colors = [self.COLORS['secondary'] if x >= 0 else self.COLORS['danger']
                          for x in revgrowth.values]
                fig.add_trace(go.Bar(
                    x=revgrowth.index, y=revgrowth.values,
                    name='Revenue Growth',
                    marker_color=colors,
                    showlegend=False
                ), row=2, col=1)

                layout = self._get_base_layout(f'{ticker} - Growth & Profitability', height=700)
                layout['yaxis1'] = dict(title='Margin (%)', gridcolor=self.COLORS['grid'])
                layout['yaxis2'] = dict(title='Growth (%)', gridcolor=self.COLORS['grid'])
                layout['barmode'] = 'group'
                fig.update_layout(**layout)
                fig.show()

            except Exception as e:
                self.logger.warning(f'Could not process {ticker}: {e}')

    def risk_and_return(self, startdate=None, stocks=None, rfr=4):
        """Print per-ticker risk/return statistics and chart an equal-weight portfolio.

        Args:
            startdate: Optional start date (anything ``pd.to_datetime``
                accepts); earlier rows are ignored.
            stocks: Tickers to include; falls back to all loaded tickers
                when None or empty.
            rfr: Annual risk-free rate in percent, used for the Sharpe ratio.
        """
        if not self.stocks_existing:
            return

        tickers = stocks if stocks else self.tickers
        if startdate:
            try:
                startdate = pd.to_datetime(startdate)
            except (ValueError, TypeError, OverflowError):
                self.logger.error('Invalid date format')
                return
        else:
            startdate = None

        # Statistics are only computed for intervals that can be annualized.
        annualize = self._PERIODS_PER_YEAR.get(self.interval)
        to_merge = []
        stat_rows = {}

        for ticker in tickers:
            try:
                close = self._read_close(ticker)
                if startdate is not None:
                    close = close.loc[close.index >= startdate]
                if close.empty:
                    # An empty series would wipe out the inner join below.
                    self.logger.warning(f'Error processing {ticker}: no price data in range')
                    continue
                to_merge.append(close.rename(ticker))

                if annualize:
                    returns = close.pct_change().dropna() * 100
                    mean_ret = returns.mean() * annualize
                    std_ret = returns.std() * np.sqrt(annualize)
                    sharpe = (mean_ret - rfr) / std_ret if std_ret > 0 else 0

                    total_ret = ((close.iloc[-1] / close.iloc[0]) - 1) * 100
                    max_dd = ((close / close.cummax()) - 1).min() * 100

                    stat_rows[ticker] = {
                        'Total_Return(%)': round(total_ret, 2),
                        'Avg_Return(%)': round(mean_ret, 2),
                        'Volatility(%)': round(std_ret, 2),
                        'Sharpe': round(sharpe, 2),
                        'Max_DD(%)': round(max_dd, 2),
                    }
            except Exception as e:
                self.logger.warning(f'Error processing {ticker}: {e}')

        stats = pd.DataFrame.from_dict(stat_rows, orient='index')
        if not stats.empty:
            print('\n=== Risk & Return Analysis ===')
            print(stats.sort_values(by='Sharpe', ascending=False).to_string())
            print()

        # Portfolio chart
        if not to_merge:
            self.logger.warning('No price data available for portfolio chart')
            return
        allstocks_df = pd.concat(to_merge, join='inner', axis=1)
        if allstocks_df.empty:
            self.logger.warning('No overlapping dates for portfolio chart')
            return
        # Equal weight: rebase every ticker to its first close before
        # averaging, so high-priced stocks do not dominate the portfolio.
        rebased = allstocks_df / allstocks_df.iloc[0]
        portfolio_value = (rebased.mean(axis=1) - 1) * 100

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=allstocks_df.index,
            y=portfolio_value,
            mode='lines',
            name='Portfolio',
            line=dict(color=self.COLORS['primary'], width=2),
            fill='tozeroy',
            fillcolor='rgba(37, 99, 235, 0.1)'
        ))

        fig.add_hline(y=0, line_dash='dash', line_color=self.COLORS['text'], opacity=0.3)

        layout = self._get_base_layout('Equal-Weight Portfolio Return', height=550)
        layout['yaxis']['title'] = 'Return (%)'
        layout['xaxis']['rangeslider'] = dict(visible=True, bgcolor=self.COLORS['bg_light'])
        fig.update_layout(**layout)
        fig.show()

    def diversification(self, stocks=None):
        """Show a donut chart of how many tickers fall into each sector.

        Args:
            stocks: Tickers to include; falls back to all loaded tickers
                when None or empty.
        """
        if not self.stocks_existing:
            return

        tickers = stocks if stocks else self.tickers
        sectors = {}

        for ticker in tickers:
            try:
                info = yf.Ticker(ticker).info
                # A missing or null sector is counted as 'Unknown' so the
                # ticker is not dropped by value_counts().
                sectors[ticker] = info.get('sector') or 'Unknown'
                time.sleep(0.5)
            except Exception as e:
                self.logger.warning(f'Could not fetch sector for {ticker}: {e}')
                sectors[ticker] = 'Unknown'

        sector_counts = pd.Series(sectors).value_counts()

        fig = go.Figure(data=[go.Pie(
            labels=sector_counts.index,
            values=sector_counts.values,
            hole=0.4,
            marker=dict(
                colors=['#2563eb', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899', '#14b8a6', '#f97316'],
                line=dict(color=self.COLORS['bg_dark'], width=2)
            ),
            textinfo='label+percent',
            textfont=dict(size=11),
            hovertemplate='<b>%{label}</b><br>Count: %{value}<br>%{percent}<extra></extra>'
        )])

        fig.add_annotation(
            text=f'<b>{len(tickers)}</b><br>Stocks<br>{len(sector_counts)} Sectors',
            showarrow=False,
            font=dict(size=14, color=self.COLORS['text']),
            x=0.5, y=0.5
        )

        layout = self._get_base_layout('Diversification', height=600)
        fig.update_layout(**layout)
        fig.show()

    def momentum_score(self, stocks=None):
        """Print and chart a momentum score for each ticker.

        The score is the mean of the percentage returns over the look-back
        windows in ``_MOMENTUM_LOOKBACKS`` (measured in price bars); windows
        longer than the available history are ignored.

        Args:
            stocks: Tickers to include; falls back to all loaded tickers
                when None or empty.
        """
        if not self.stocks_existing:
            return

        tickers = stocks if stocks else self.tickers
        rows = {}

        for ticker in tickers:
            try:
                close = self._read_close(ticker)

                # Return from `bars` rows back to the latest close.
                returns = {
                    label: ((close.iloc[-1] / close.iloc[-bars]) - 1) * 100
                    if len(close) >= bars else np.nan
                    for label, bars in self._MOMENTUM_LOOKBACKS.items()
                }
                values = list(returns.values())
                # Plain average of the available returns; skip nanmean on
                # an all-NaN list to avoid its RuntimeWarning.
                score = np.nan if all(pd.isna(v) for v in values) else np.nanmean(values)

                row = {label: round(value, 2) for label, value in returns.items()}
                row['Score'] = round(score, 2)
                rows[ticker] = row

            except Exception as e:
                self.logger.warning(f'Could not calculate momentum for {ticker}: {e}')

        momentum = pd.DataFrame.from_dict(rows, orient='index')
        if not momentum.empty:
            momentum.sort_values(by='Score', ascending=False, inplace=True)
            print('\n=== Momentum Analysis ===')
            print(momentum.to_string())
            print()

            # Visualization
            fig = go.Figure(data=[go.Bar(
                x=momentum.index,
                y=momentum['Score'],
                marker_color=[self.COLORS['secondary'] if x >= 0 else self.COLORS['danger']
                              for x in momentum['Score']],
                text=momentum['Score'],
                texttemplate='%{text:.1f}%',
                textposition='outside'
            )])

            layout = self._get_base_layout('Momentum Score', height=500)
            layout['yaxis']['title'] = 'Score (%)'
            layout['xaxis']['tickangle'] = -45
            fig.update_layout(**layout)
            fig.show()

    def __call__(self):
        """Run every analysis, passing the P/E table's ticker order to the rest."""
        dc = self.price_to_earnings()
        self.diversification(dc)
        self.momentum_score(dc)
        self.return_correlation(dc)
        self.risk_and_return('2020-06-01', dc)
        self.growth(dc)


if __name__ == '__main__':
    # Show warnings and errors emitted while loading and analysing data.
    logging.basicConfig(level=logging.WARNING)
    # Build the research object from the ticker list and settings file,
    # then call it to run every analysis.
    obj = StockResearch(file='lists/pf.txt', configs='configs/configs_short.json')
    obj()