In [None]:
import sqlite3
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime, timedelta
import logging
from IPython.display import display, HTML
from typing import Optional, Tuple, Union
import ipywidgets as widgets
from ipywidgets import interact

# Root logger setup: emit INFO and above as "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class ExportAnalyzer:
    """Query weekly export-sales data from a SQLite database and chart it.

    Commodity and country metadata are read from ``db_path`` when the object
    is created. Per-commodity state -- ``unit_info`` (from
    :meth:`get_unit_info`), ``my_dates`` (from
    :meth:`get_marketing_year_info`) and ``my_range`` (from
    :meth:`set_marketing_years`) -- must be populated before
    :meth:`load_data` is called; :meth:`create_interactive_dashboard` does
    this automatically.
    """

    def __init__(self, db_path: str = "esr_data.db"):
        self.db_path = db_path
        # Metric column name -> human-readable label used in titles and summaries.
        self.metrics = {
            'weeklyExports': 'Weekly Exports',
            'accumulatedExports': 'Accumulated Exports',
            'outstandingSales': 'Outstanding Sales',
            'grossNewSales': 'Gross New Sales',
            'netSales': 'Net Sales',
            'totalCommitment': 'Total Commitment'
        }
        self.commodities = self.get_commodities()
        self.countries = self.get_countries()
        # Per-commodity state, populated before load_data() is used.
        self.my_dates = None
        self.unit_info = None
        self.my_range = None
        self.exports_df = None

    def _read_sql(self, query: str, params: Optional[tuple] = None) -> pd.DataFrame:
        """Run ``query`` on a fresh connection to ``db_path`` and always close it.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back the transaction -- it does not close the connection -- so the
        connection is closed explicitly to avoid leaking one per query.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            return pd.read_sql(query, conn, params=params)
        finally:
            conn.close()

    def get_commodities(self) -> pd.DataFrame:
        """Return every commodity (``commodityCode``, ``commodityName``) ordered by name."""
        return self._read_sql("""
            SELECT commodityCode, commodityName 
            FROM metadata_commodities 
            ORDER BY commodityName
        """)

    def get_countries(self) -> pd.DataFrame:
        """Return distinct (``countryCode``, ``countryName``) pairs ordered by name."""
        return self._read_sql("""
            SELECT DISTINCT countryCode, countryName 
            FROM metadata_countries 
            ORDER BY countryName
        """)

    def get_unit_info(self, commodity_code: int) -> dict:
        """Return name and unit details for one commodity.

        Returns:
            A dict with keys ``commodity_code``, ``commodity_name``,
            ``unit_id`` and ``unit_name`` (taken from the first matching row).

        Raises:
            ValueError: if no commodity joined to a unit has ``commodity_code``.
        """
        info = self._read_sql("""
            SELECT 
                m.commodityCode,
                m.commodityName,
                m.unitId,
                u.unitNames
            FROM metadata_commodities m
            JOIN metadata_units u ON m.unitId = u.unitId
            WHERE m.commodityCode = ?
        """, params=(commodity_code,))

        if info.empty:
            raise ValueError(f"No commodity found with code {commodity_code}")

        return {
            'commodity_code': info['commodityCode'].iloc[0],
            'commodity_name': info['commodityName'].iloc[0],
            'unit_id': info['unitId'].iloc[0],
            'unit_name': info['unitNames'].iloc[0]
        }

    def get_marketing_year_info(self, commodity_code: int) -> pd.DataFrame:
        """Return marketing-year start/end dates for a commodity plus one projected year.

        Rows come from ``data_releases`` (one per marketing year after
        de-duplication). A synthetic row for the year after the latest one is
        appended, with start/end dates shifted forward by one year, so that
        next-marketing-year rows produced by :meth:`load_data` can be dated.

        Raises:
            ValueError: if ``data_releases`` has no rows for ``commodity_code``.
        """
        my_dates = self._read_sql("""
            SELECT 
                marketYear,
                marketYearStart,
                marketYearEnd
            FROM data_releases
            WHERE commodityCode = ?
            ORDER BY marketYear
        """, params=(commodity_code,))

        if my_dates.empty:
            raise ValueError(f"No marketing year data for commodity {commodity_code}")

        my_dates['marketYearStart'] = pd.to_datetime(my_dates['marketYearStart'])
        my_dates['marketYearEnd'] = pd.to_datetime(my_dates['marketYearEnd'])

        # Keep one row per marketing year. Repeated years (NOTE(review): e.g. one
        # row per release -- confirm against the table) would otherwise duplicate
        # export rows in the merge done by load_data().
        my_dates = my_dates.drop_duplicates('marketYear', keep='first').reset_index(drop=True)

        latest = my_dates.loc[my_dates['marketYear'].idxmax()]
        one_year = pd.offsets.DateOffset(years=1)
        next_year_data = pd.DataFrame({
            'marketYear': [latest['marketYear'] + 1],
            'marketYearStart': [latest['marketYearStart'] + one_year],
            'marketYearEnd': [latest['marketYearEnd'] + one_year]
        })

        return pd.concat([my_dates, next_year_data], ignore_index=True)

    def set_marketing_years(self, start_my: int, end_my: int) -> None:
        """Select the inclusive marketing-year range used by load_data and the plots.

        Raises:
            ValueError: if ``start_my > end_my``, if ``my_dates`` has not been
                loaded yet, or if any year in the range is absent from it.
        """
        if start_my > end_my:
            raise ValueError("Start marketing year must be <= end marketing year")
        if self.my_dates is None:
            raise ValueError("Marketing year dates not loaded; call get_marketing_year_info() first")
        known_years = set(self.my_dates['marketYear'])
        missing = [my for my in range(start_my, end_my + 1) if my not in known_years]
        if missing:
            raise ValueError(f"Some specified marketing years not found in database: {missing}")
        self.my_range = (start_my, end_my)
        logging.info(f"Set marketing year range: {self.my_range}")

    @staticmethod
    def _standardize_columns(frame: pd.DataFrame, keep: dict, other: dict) -> pd.DataFrame:
        """Drop ``other``'s source columns and expose ``keep``'s sources under standard names."""
        result = frame.drop(columns=list(other.values()))
        renamed = {std_col: src_col for std_col, src_col in keep.items() if std_col != src_col}
        for std_col, src_col in renamed.items():
            result[std_col] = result[src_col]
        return result.drop(columns=list(renamed.values()))

    def load_data(self, commodity_code: int) -> pd.DataFrame:
        """Load export rows for ``commodity_code`` within ``my_range`` in long format.

        Each source row is emitted twice: once for its own ("current")
        marketing year with all six metrics, and once with ``market_year + 1``
        for the following ("next") marketing year, carrying only ``netSales``
        and ``outstandingSales`` (the other metrics are NaN on those rows).
        Marketing-year dates from ``my_dates`` are merged in and
        ``weeks_into_my`` is the whole number of weeks since
        ``marketYearStart`` plus one (week 1 is the first week). Assuming
        report dates precede the next year's start, next-MY rows land at
        week <= 0 of the year they are assigned to.

        Returns:
            The processed rows sorted by ``weekEndingDate``, or an empty
            DataFrame (no columns) when the query matches nothing.

        Raises:
            ValueError: if ``my_range``, ``my_dates`` or ``unit_info`` is unset.
        """
        if self.my_range is None or self.my_dates is None or self.unit_info is None:
            raise ValueError("Load unit info and marketing year dates and call set_marketing_years() before load_data()")

        exports_df = self._read_sql("""
            SELECT 
                e.*, 
                c.commodityName,
                mc.countryName,
                mc.countryDescription,
                mc.regionId,
                u.unitNames as unit
            FROM commodity_exports e
            JOIN metadata_commodities c ON e.commodityCode = c.commodityCode
            JOIN metadata_countries mc ON e.countryCode = mc.countryCode
            JOIN metadata_units u ON e.unitId = u.unitId
            WHERE e.commodityCode = ?
            AND e.market_year BETWEEN ? AND ?
            ORDER BY weekEndingDate
        """, params=(commodity_code, self.my_range[0], self.my_range[1]))

        if exports_df.empty:
            logging.warning(f"No export data for commodity {commodity_code} in years {self.my_range}")
            return pd.DataFrame()

        exports_df['weekEndingDate'] = pd.to_datetime(exports_df['weekEndingDate'])

        # Standard metric name -> source column, per marketing-year basis.
        column_mappings = {
            'current': {
                'netSales': 'currentMYNetSales',
                'totalCommitment': 'currentMYTotalCommitment',
                'outstandingSales': 'outstandingSales',
                'accumulatedExports': 'accumulatedExports',
                'weeklyExports': 'weeklyExports',
                'grossNewSales': 'grossNewSales'
            },
            'next': {
                'netSales': 'nextMYNetSales',
                'outstandingSales': 'nextMYOutstandingSales',
            }
        }

        key_columns = ['weekEndingDate', 'market_year', 'countryCode']
        if exports_df.duplicated(key_columns).any():
            logging.warning("Duplicate entries detected in raw data")
            exports_df = exports_df.drop_duplicates(key_columns, keep='first')

        current_my_data = self._standardize_columns(
            exports_df, column_mappings['current'], column_mappings['next'])
        next_my_data = self._standardize_columns(
            exports_df, column_mappings['next'], column_mappings['current'])
        next_my_data['market_year'] = next_my_data['market_year'] + 1

        processed_data = pd.concat([current_my_data, next_my_data], ignore_index=True)

        for col in self.metrics:
            if col in processed_data.columns:
                processed_data[col] = pd.to_numeric(processed_data[col], errors='coerce')

        processed_data['display_units'] = self.unit_info['unit_name']
        processed_data = processed_data.merge(self.my_dates, left_on='market_year', right_on='marketYear', how='left')

        # Vectorized week offset; NaT on either side yields NaN.
        elapsed_days = (processed_data['weekEndingDate'] - processed_data['marketYearStart']).dt.days
        processed_data['weeks_into_my'] = elapsed_days // 7 + 1

        processed_data = processed_data.sort_values('weekEndingDate').reset_index(drop=True)
        logging.info(f"Loaded {len(processed_data)} records for commodity {commodity_code}")
        return processed_data

    @staticmethod
    def _filter_country(exports_df: pd.DataFrame, country: Optional[str]) -> Tuple[pd.DataFrame, str]:
        """Return rows for ``country`` plus a title suffix; None/"All Countries" keeps every row."""
        if country and country != "All Countries":
            return exports_df[exports_df['countryName'] == country], f" - {country}"
        return exports_df, ""

    @staticmethod
    def _apply_standard_layout(fig: go.Figure, *, title: str, xaxis_title: str,
                               yaxis_title: str, height: int, **extra_layout) -> None:
        """Apply the layout shared by every chart: size, template and right-hand legend."""
        fig.update_layout(
            title=title,
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
            showlegend=True,
            height=height,
            width=1000,
            template='plotly_white',
            legend=dict(
                x=1.05,
                y=1,
                xanchor='left',
                yanchor='top',
                bgcolor='rgba(255,255,255,0.5)',
                bordercolor='black',
                borderwidth=1,
                font=dict(size=10),
                traceorder='normal',
                itemsizing='constant',
                itemwidth=30,
                orientation='v',
                tracegroupgap=0
            ),
            margin=dict(l=50, r=200, t=100, b=50),
            **extra_layout
        )

    def plot_weekly_metric(self, exports_df: pd.DataFrame, metric: str, country: Optional[str] = None) -> go.Figure:
        """Bar chart of ``metric`` summed per week-ending date, one overlaid series per market year."""
        if exports_df.empty:
            return go.Figure()

        filtered_df, title_suffix = self._filter_country(exports_df, country)
        weekly_data = filtered_df.groupby(['market_year', 'weekEndingDate'])[metric].sum().reset_index()

        fig = go.Figure()
        for year in sorted(weekly_data['market_year'].unique()):
            year_data = weekly_data[weekly_data['market_year'] == year]
            fig.add_trace(go.Bar(x=year_data['weekEndingDate'], y=year_data[metric], name=f'MY {year-1}/{year}'))

        self._apply_standard_layout(
            fig,
            title=f'{self.metrics[metric]} - Weekly Trend (MY {self.my_range[0]}-{self.my_range[1]}){title_suffix}',
            xaxis_title='Week Ending Date',
            yaxis_title=exports_df['display_units'].iloc[0],
            height=600,
            barmode='overlay'
        )
        return fig

    def plot_weekly_metric_country(self, exports_df: pd.DataFrame, metric: str, country: Optional[str] = None) -> go.Figure:
        """Stacked bar chart of ``metric`` per week-ending date, one series per country."""
        if exports_df.empty:
            return go.Figure()

        filtered_df, title_suffix = self._filter_country(exports_df, country)
        weekly_data = filtered_df.groupby(['market_year', 'weekEndingDate', 'countryName'])[metric].sum().reset_index()

        fig = go.Figure()
        for country_name in sorted(weekly_data['countryName'].unique()):
            country_data = weekly_data[weekly_data['countryName'] == country_name]
            fig.add_trace(go.Bar(x=country_data['weekEndingDate'], y=country_data[metric], name=country_name))

        self._apply_standard_layout(
            fig,
            title=f'{self.metrics[metric]} - Weekly Trend by Country (MY {self.my_range[0]}-{self.my_range[1]}){title_suffix}',
            xaxis_title='Week Ending Date',
            yaxis_title=exports_df['display_units'].iloc[0],
            height=800,
            barmode='stack'
        )
        return fig

    def plot_marketing_year_metric(self, exports_df: pd.DataFrame, metric: str, country: Optional[str] = None) -> go.Figure:
        """Line chart of weekly ``metric`` against week-of-marketing-year, one line per MY in ``my_range``.

        Only rows at week >= 1 of their assigned marketing year are plotted.
        Each line runs from week 1 to that year's own last reported week.
        Weeks inside that span with no rows are drawn as 0; weeks after the
        last report are left off, so an in-progress year is not zero-padded.
        """
        if exports_df.empty:
            logging.warning(f"No data to plot for {metric} MY comparison")
            return go.Figure()

        filtered_df, title_suffix = self._filter_country(exports_df, country)
        fig = go.Figure()

        if 'weeks_into_my' in filtered_df.columns:
            # Weeks <= 0 are next-MY rows dated before their year starts; NaN
            # weeks have no marketing-year start date. Neither is plotted.
            in_year = filtered_df[filtered_df['weeks_into_my'] >= 1]
            start_my, end_my = self.my_range
            for year in range(start_my, end_my + 1):
                year_data = in_year[in_year['market_year'] == year]
                if year_data.empty:
                    continue

                weekly = year_data.groupby(year_data['weeks_into_my'].astype(int))[metric].sum()
                # Fill gaps only up to this year's last reported week.
                weekly = weekly.reindex(range(1, int(weekly.index.max()) + 1), fill_value=0)

                start = year_data['marketYearStart'].iloc[0]
                start_label = start.strftime('%b %d') if pd.notna(start) else 'Unknown'

                fig.add_trace(go.Scatter(
                    x=list(weekly.index),
                    y=weekly.values,
                    name=f'MY {year-1}/{year} (Start: {start_label})',
                    mode='lines'
                ))

        units = exports_df['display_units'].iloc[0]
        self._apply_standard_layout(
            fig,
            title=f'Weekly {self.metrics[metric]} - Marketing Year Comparison{title_suffix}',
            xaxis_title='Weeks into Marketing Year',
            yaxis_title=f'{units} per Week',
            height=800,
            xaxis=dict(tickmode='linear', dtick=4)
        )
        return fig

    def _summarize_latest_my(self, exports_df: pd.DataFrame, metric: str) -> pd.DataFrame:
        """One-row summary of ``metric`` for the latest marketing year with in-year data.

        Only rows at week >= 1 of their marketing year are used. The
        duplicated next-MY rows from load_data (assigned to
        ``market_year + 1`` and, assuming report dates precede that year's
        start, at week <= 0) therefore neither become the "latest" year nor
        get added to the current year's figures.
        """
        in_year = exports_df[exports_df['weeks_into_my'] >= 1]
        latest_my_rows = in_year[in_year['market_year'] == in_year['market_year'].max()]
        latest_week_rows = latest_my_rows[latest_my_rows['weekEndingDate'] == latest_my_rows['weekEndingDate'].max()]
        return pd.DataFrame({
            'Metric': [self.metrics[metric]],
            'Latest Week': [latest_week_rows[metric].sum()],
            'MY Total': [latest_my_rows[metric].sum()],
            'Units': [exports_df['display_units'].iloc[0]]
        })

    def create_interactive_dashboard(self):
        """Display ipywidgets controls that reload data and redraw the chart on change.

        Choosing a commodity reloads its unit and marketing-year metadata once
        and builds a marketing-year range slider. The slider defaults to the
        last five available years, or all of them when fewer exist. Each
        slider change reloads the export data, then shows a heading, a one-row
        summary table and the selected chart.
        """
        commodity_dropdown = widgets.Dropdown(
            options=[(row['commodityName'], row['commodityCode']) for _, row in self.commodities.iterrows()],
            description='Commodity:'
        )

        country_dropdown = widgets.Dropdown(
            options=[("All Countries", "All Countries")] + [(row['countryName'], row['countryName']) for _, row in self.countries.iterrows()],
            description='Country:',
            value="All Countries"
        )

        metric_dropdown = widgets.Dropdown(
            options=[(v, k) for k, v in self.metrics.items()],
            description='Metric:'
        )

        plot_type_dropdown = widgets.Dropdown(
            options=[
                ('Weekly Trend', 'weekly'),
                ('Weekly by Country', 'country'),
                ('MY Comparison', 'my_comparison')
            ],
            description='Plot Type:'
        )

        # Any unrecognized plot type falls back to the MY comparison chart.
        plotters = {
            'weekly': self.plot_weekly_metric,
            'country': self.plot_weekly_metric_country,
        }

        def build_years_slider() -> widgets.IntRangeSlider:
            # Assumes self.my_dates was just loaded for the selected commodity.
            years = sorted(int(year) for year in self.my_dates['marketYear'].unique())
            default_start = years[max(len(years) - 5, 0)]
            return widgets.IntRangeSlider(
                value=[default_start, years[-1]],
                min=years[0],
                max=years[-1],
                step=1,
                description='MY Range:',
                continuous_update=False
            )

        @interact
        def update_plot(commodity=commodity_dropdown,
                        country=country_dropdown,
                        metric=metric_dropdown,
                        plot_type=plot_type_dropdown):
            try:
                self.unit_info = self.get_unit_info(commodity)
                self.my_dates = self.get_marketing_year_info(commodity)
            except ValueError as err:
                logging.warning(str(err))
                print(err)
                return

            @interact
            def final_plot(years=build_years_slider()):
                self.set_marketing_years(years[0], years[1])
                self.exports_df = self.load_data(commodity)

                display(HTML(f"<h3>Commodity: {self.unit_info['commodity_name']}</h3>"))
                if self.exports_df.empty:
                    print(f"No export data for MY {years[0]}-{years[1]}.")
                    return

                plot = plotters.get(plot_type, self.plot_marketing_year_metric)
                fig = plot(self.exports_df, metric, country)
                display(self._summarize_latest_my(self.exports_df, metric))
                fig.show()

if __name__ == "__main__":
    # Entry point when the file/cell is run directly: build the analyzer and
    # show the dashboard. Any failure is logged with its full traceback.
    try:
        analyzer = ExportAnalyzer()
        analyzer.create_interactive_dashboard()
    except Exception as e:
        logging.exception("Error in analysis: %s", e)