# In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Load the industry price table. IndustryAnalysis below expects a 'Dates'
# column plus one price column per industry (e.g. 'Healthcare', 'Financials').
file_path = 'Industry.xlsx - Prices.csv'
industry_data = pd.read_csv(file_path)
# A bare `industry_data.head()` in the middle of a cell is discarded (only a
# cell's last expression is displayed), so print the preview explicitly.
print(industry_data.head())

class IndustryAnalysis:
    """
    Compute and plot per-industry return and moving-average series.

    The price table must contain a 'Dates' column plus one price column per
    industry; each industry column is analysed as a series indexed by 'Dates'.
    """

    # Accepted values for the ``plot_type`` argument of the plotting methods.
    _PLOT_TYPES = ('returns', 'moving_average')

    def __init__(self, price_data):
        """
        Store a copy of the price table with 'Dates' parsed to datetimes.

        :param price_data: DataFrame with a 'Dates' column and one price column per industry.
        """
        # Copy first so parsing 'Dates' does not mutate the caller's DataFrame.
        self.price_data = price_data.copy()
        self.price_data['Dates'] = pd.to_datetime(self.price_data['Dates'])

    def _price_series(self, industry):
        """
        Return the price column for ``industry`` indexed by 'Dates'.

        :param industry: Name of the industry column.
        :return: Series of prices indexed by date.
        :raises ValueError: If ``industry`` is not a column of the price data.
        """
        if industry not in self.price_data.columns:
            raise ValueError(f"Industry '{industry}' not found in data.")
        return self.price_data.set_index('Dates')[industry]

    def calc_moving_average(self, industry, window=5):
        """
        Calculate the rolling mean of prices for a given industry.

        :param industry: Name of the industry column.
        :param window: Number of rows (weeks, if the data is weekly) per window.
        :return: Series indexed by date; values are NaN until a full window of
            non-missing prices is available.
        :raises ValueError: If ``industry`` is not a column of the price data.
        """
        return self._price_series(industry).rolling(window=window).mean()

    def calc_period_returns(self, industry):
        """
        Calculate row-over-row (weekly, for weekly data) returns for an industry.

        :param industry: Name of the industry column.
        :return: Series of fractional price changes indexed by date; the first value is NaN.
        :raises ValueError: If ``industry`` is not a column of the price data.
        """
        return self._price_series(industry).pct_change(periods=1)

    def events_plot(self, events, ax):
        """
        Shade each event's date range on ``ax`` and label it at the top of the axes.

        :param events: Iterable of (start, end, name) tuples; start/end may be
            anything ``pd.to_datetime`` accepts.
        :param ax: Matplotlib Axes object with a date x-axis.
        """
        # x in data (date) coordinates, y in axes coordinates (1.0 = top edge), so
        # labels stay pinned to the top even if the y-limits change afterwards.
        label_transform = ax.get_xaxis_transform()
        for start, end, name in events:
            start_ts = pd.to_datetime(start)
            end_ts = pd.to_datetime(end)
            ax.axvspan(start_ts, end_ts, color='gray', alpha=0.3)
            mid_point = start_ts + (end_ts - start_ts) / 2
            ax.text(mid_point, 1.0, name, transform=label_transform,
                    horizontalalignment='center', verticalalignment='bottom',
                    fontsize=8, rotation=45)

    def _draw_series(self, industry, plot_type, window):
        """
        Validate inputs, create a figure and plot the requested series on it.

        :param industry: Name of the industry column.
        :param plot_type: 'returns' or 'moving_average'.
        :param window: Rolling-window length for 'moving_average'.
        :return: The Matplotlib Axes holding the plotted line.
        :raises ValueError: For an unknown ``plot_type`` or ``industry``.
        """
        if plot_type not in self._PLOT_TYPES:
            raise ValueError("plot_type must be 'returns' or 'moving_average'.")
        # Compute the data before creating the figure so invalid input does not
        # leave an empty figure behind.
        if plot_type == 'returns':
            data = self.calc_period_returns(industry)
            label, color = f'{industry} Weekly Returns', 'blue'
        else:
            data = self.calc_moving_average(industry, window)
            label, color = f'{industry} Moving Average', 'green'
        _, ax = plt.subplots(figsize=(12, 6))
        ax.plot(data.index, data, label=label, color=color)
        return ax

    @staticmethod
    def _finish_plot(ax, title):
        """Apply the shared title, axis labels and legend, then show the figure."""
        ax.set_title(title)
        ax.set_xlabel('Date')
        ax.set_ylabel('Value')
        ax.legend()
        plt.show()

    def series_plot(self, industry, plot_type='returns', window=5):
        """
        Plot periodic returns or the moving average of prices for a given industry.

        :param industry: Name of the industry.
        :param plot_type: Type of plot ('returns' or 'moving_average').
        :param window: Rolling-window length used when ``plot_type`` is 'moving_average'.
        :raises ValueError: For an unknown ``plot_type`` or ``industry``.
        """
        ax = self._draw_series(industry, plot_type, window)
        self._finish_plot(ax, f'{industry} {plot_type.replace("_", " ").title()}')

    def plot_all(self, industry, events, plot_type='returns', window=5):
        """
        Plot an industry series with shaded event windows overlaid.

        :param industry: Name of the industry.
        :param events: List of (start date, end date, name) tuples.
        :param plot_type: Type of plot ('returns' or 'moving_average').
        :param window: Rolling-window length used when ``plot_type`` is 'moving_average'.
        :raises ValueError: For an unknown ``plot_type`` or ``industry``.
        """
        ax = self._draw_series(industry, plot_type, window)
        self.events_plot(events, ax)
        self._finish_plot(
            ax, f'{industry} {plot_type.replace("_", " ").title()} with Events')

# Event windows to overlay: (start date, end date, label). Defined at module
# level so the plot_all calls below can reference it.
updated_events = [('2000-01-01', '2000-05-01', 'Event 1'), ('2001-05-01', '2001-06-01', 'Event 2')]

industry_analysis = IndustryAnalysis(industry_data)
industry_analysis.series_plot('Healthcare', plot_type='returns')
industry_analysis.plot_all('Financials', updated_events, plot_type='returns')
industry_analysis.plot_all('Financials', updated_events, plot_type='moving_average')