# In[ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import plotly.express as px
import plotly.graph_objects as go
import warnings
# Session-wide configuration.
# NOTE(review): this silences *every* warning for the rest of the session
# (pandas/seaborn deprecations included) — consider narrowing the filter.
warnings.filterwarnings(action='ignore')

# Shared look for every chart below: ggplot theme, viridis colour cycle and a
# 12x8-inch default canvas (individual cells pass their own figsize).
plt.style.use('ggplot')
sns.set_palette('viridis')
plt.rcParams.update({'figure.figsize': (12, 8)})

# Load the dataset.
# NOTE(review): path is relative to the notebook's working directory — confirm.
print("Loading sales data...")
sales_data = pd.read_csv('../sales_data.csv')

# Display basic information
print("\n--- Dataset Overview ---")
print(f"Dataset shape: {sales_data.shape}")
print(f"Number of unique distributors: {sales_data['distributor_id'].nunique()}")
print(f"Number of unique SKUs: {sales_data['sku'].nunique()}")

# Year and quarter must be ordered *together*: taking min/max of each column
# independently can report a period that never occurs in the data
# (e.g. "2021-1" when the data actually starts at 2021-Q3).
observed_periods = (
    sales_data[['year', 'quarter']]
    .dropna()  # min()/max() skipped NaN, so ignore incomplete periods here too
    .drop_duplicates()
    .sort_values(['year', 'quarter'])
)
if observed_periods.empty:
    print("Time range: unavailable (no year/quarter values)")
else:
    (start_year, start_quarter) = observed_periods.iloc[0]
    (end_year, end_quarter) = observed_periods.iloc[-1]
    print(f"Time range: {start_year}-{start_quarter} to {end_year}-{end_quarter}")

# Display first few rows. A bare expression in the middle of a cell is not
# rendered by Jupyter (only the cell's last expression is), so print explicitly.
print("\n--- Sample Data ---")
print(sales_data.head())

# Check column types and missing values: one row per column with its dtype,
# count of nulls, and nulls as a percentage of all rows.
print("\n--- Data Types and Missing Values ---")
missing_counts = sales_data.isnull().sum()
data_types = pd.DataFrame({
    'Data Type': sales_data.dtypes,
    'Missing Values': missing_counts,
    'Missing Percentage': (missing_counts / len(sales_data) * 100).round(2),
})
# Print explicitly: a bare mid-cell expression is never rendered by Jupyter.
print(data_types)

# Basic statistics for numerical columns
print("\n--- Numerical Features Statistics ---")
print(sales_data.describe())

# Distribution of sales, shown twice side by side: raw values, and
# log1p-transformed values (easier to read if the raw values are heavily skewed).
dist_fig, (raw_ax, log_ax) = plt.subplots(1, 2, figsize=(14, 6))

sns.histplot(sales_data['sales'], kde=True, ax=raw_ax)
raw_ax.set_title('Distribution of Sales')
raw_ax.set_xlabel('Sales Value')

# log1p(x) = log(1 + x); NOTE(review): values below -1 would become NaN.
sns.histplot(np.log1p(sales_data['sales']), kde=True, ax=log_ax)
log_ax.set_title('Distribution of Log Sales')
log_ax.set_xlabel('Log(Sales + 1)')

dist_fig.tight_layout()
plt.show()

# Category-based analysis: mean / median / total / row count of sales per
# product category, ordered by total sales (largest first).
print("\n--- Category Analysis ---")
category_sales = (
    sales_data.groupby('category')['sales']
    .agg(['mean', 'median', 'sum', 'count'])
    .sort_values('sum', ascending=False)
)
# Print explicitly: a bare mid-cell expression is never rendered by Jupyter.
print(category_sales)

# Plot category sales; `order` keeps the bars in the ranked order computed above.
plt.figure(figsize=(14, 8))
sns.barplot(x=category_sales.index, y=category_sales['sum'], order=category_sales.index)
plt.title('Total Sales by Product Category')
plt.xlabel('Category')
plt.ylabel('Total Sales')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Industry-based analysis: aggregate sales per industry and chart the totals,
# largest industry first.
industry_sales = (
    sales_data.groupby('industry')['sales']
    .agg(['mean', 'median', 'sum', 'count'])
    .sort_values('sum', ascending=False)
)

industry_fig, industry_ax = plt.subplots(figsize=(14, 6))
sns.barplot(
    x=industry_sales.index,
    y=industry_sales['sum'],
    order=industry_sales.index,
    ax=industry_ax,
)
industry_ax.set_title('Total Sales by Industry')
industry_ax.set_xlabel('Industry')
industry_ax.set_ylabel('Total Sales')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Temporal patterns: total sales per (year, quarter), plotted chronologically
# against a "YYYY-Qn" label.
print("\n--- Temporal Patterns ---")
sales_by_time = (
    sales_data.groupby(['year', 'quarter'])['sales']
    .sum()
    .reset_index()
    .sort_values(['year', 'quarter'])
)
sales_by_time['year_quarter'] = (
    sales_by_time['year'].astype(str)
    + '-Q'
    + sales_by_time['quarter'].astype(str)
)

time_fig, time_ax = plt.subplots(figsize=(16, 6))
time_ax.plot(
    sales_by_time['year_quarter'],
    sales_by_time['sales'],
    marker='o',
    linestyle='-',
)
time_ax.set_title('Total Sales Over Time')
time_ax.set_xlabel('Year-Quarter')
time_ax.set_ylabel('Total Sales')
plt.xticks(rotation=45)
time_ax.grid(True)
plt.tight_layout()
plt.show()

# Quarterly sales patterns: sum sales per quarter value across every year,
# to expose any recurring within-year seasonality.
quarterly_sales = sales_data.groupby('quarter')['sales'].sum()

quarter_fig, quarter_ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=quarterly_sales.index, y=quarterly_sales.values, ax=quarter_ax)
quarter_ax.set_title('Sales by Quarter (All Years Combined)')
quarter_ax.set_xlabel('Quarter')
quarter_ax.set_ylabel('Total Sales')
quarter_ax.grid(True, axis='y')
plt.show()

# Festival impact analysis.
# For each festival flag, compare mean sales on rows where the flag equals 1
# against rows where it equals 0 and express the difference as a % uplift.
# NOTE(review): this is a raw mean comparison — it does not control for
# quarter, category, etc., so read it as descriptive rather than causal.
print("\n--- Festival Impact Analysis ---")
festival_columns = ['is_diwali', 'is_ganesh_chaturthi', 'is_gudi_padwa', 'is_eid',
                    'is_akshay_tritiya', 'is_dussehra_navratri', 'is_onam', 'is_christmas']

festival_rows = {}
for festival in festival_columns:
    flag = sales_data[festival]
    # Compute each group mean once instead of re-filtering the frame per use.
    mean_with = sales_data.loc[flag == 1, 'sales'].mean()
    mean_without = sales_data.loc[flag == 0, 'sales'].mean()
    # A missing (NaN) or zero baseline would yield NaN/inf; an infinite bar
    # height can break the y-axis scaling below, so report it as NaN instead.
    if pd.notna(mean_without) and mean_without != 0:
        impact_pct = (mean_with / mean_without - 1) * 100
    else:
        impact_pct = np.nan
    festival_rows[festival] = {
        'With Festival': mean_with,
        'Without Festival': mean_without,
        'Impact (%)': impact_pct,
    }

# One row per festival, sorted by uplift (largest first; NaN rows sort last).
festival_impact = pd.DataFrame.from_dict(
    festival_rows,
    orient='index',
    columns=['With Festival', 'Without Festival', 'Impact (%)'],
).sort_values('Impact (%)', ascending=False)

plt.figure(figsize=(14, 8))
sns.barplot(x=festival_impact.index, y=festival_impact['Impact (%)'], palette='coolwarm')
plt.title('Festival Impact on Sales (%)')
plt.xlabel('Festival')
plt.ylabel('Impact on Average Sales (%)')
plt.xticks(rotation=45)
plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()

# Movement category analysis: mean sales per movement category.
movement_cat = sales_data.groupby('movement_category')['sales'].agg(['mean', 'median', 'count']).reset_index()

# Keep the intended slow -> fast ordering, but append any category that is not
# in the preferred list: with an explicit `order`, seaborn only draws the
# listed levels, so an unexpected label would otherwise vanish silently.
preferred_movement_order = ['Slow Moving', 'Medium', 'Fast Moving']
movement_order = preferred_movement_order + [
    level for level in movement_cat['movement_category']
    if level not in preferred_movement_order
]

plt.figure(figsize=(10, 6))
sns.barplot(x='movement_category', y='mean', data=movement_cat, order=movement_order)
plt.title('Average Sales by Movement Category')
plt.xlabel('Movement Category')
plt.ylabel('Average Sales')
plt.grid(True, axis='y')
plt.show()

# Lag relationship: scatter current-quarter sales against the previous
# quarter's, coloured by movement category, then report the correlation.
lag_fig, lag_ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(
    data=sales_data,
    x='prev_quarter_sales',
    y='sales',
    hue='movement_category',
    alpha=0.6,
    ax=lag_ax,
)
lag_ax.set_title('Current Quarter Sales vs Previous Quarter Sales')
lag_ax.set_xlabel('Previous Quarter Sales')
lag_ax.set_ylabel('Current Quarter Sales')
lag_ax.grid(True)
plt.show()

# Series.corr defaults to Pearson and excludes rows where either value is missing.
lag_correlation = sales_data['sales'].corr(sales_data['prev_quarter_sales'])
print(f"Correlation between current and previous quarter sales: {lag_correlation:.3f}")

# Analyzing top distributors: the 10 distributors with the largest total sales,
# already in descending order after nlargest().
top_distributors = sales_data.groupby('distributor_id')['sales'].sum().nlargest(10).reset_index()

plt.figure(figsize=(12, 6))
# Pass an explicit order (as the category/industry charts do): without it,
# seaborn sorts numeric IDs by value, which scrambles the sales ranking.
sns.barplot(
    x='distributor_id',
    y='sales',
    data=top_distributors,
    order=top_distributors['distributor_id'].tolist(),
)
plt.title('Top 10 Distributors by Total Sales')
plt.xlabel('Distributor ID')
plt.ylabel('Total Sales')
plt.xticks(rotation=45)
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()

# SKU popularity: the 10 SKUs with the largest total sales, already in
# descending order after nlargest().
top_skus = sales_data.groupby('sku')['sales'].sum().nlargest(10).reset_index()

plt.figure(figsize=(14, 6))
# Explicit order keeps the ranking even if SKU codes are numeric (seaborn
# would otherwise sort numeric levels by value).
sns.barplot(
    x='sku',
    y='sales',
    data=top_skus,
    order=top_skus['sku'].tolist(),
)
plt.title('Top 10 SKUs by Total Sales')
plt.xlabel('SKU')
plt.ylabel('Total Sales')
plt.xticks(rotation=45)
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()

# Hand-off notes for the TFT modelling step.
# NOTE(review): these statements are hard-coded text, not computed from the
# analysis above — re-check them whenever the data changes.
tft_model_notes = [
    "1. Key static categorical variables: 'distributor_id', 'industry', 'sku', 'category', 'movement_category'",
    "2. Key time-varying variables: 'sales', 'prev_quarter_sales', 'total_quarter_sales', 'avg_quarterly_sales'",
    "3. Festival flags are important seasonal indicators",
    "4. Time index is available as 'time_idx' for temporal ordering",
    "5. Target variable: 'sales' for next period",
]
print("\n--- Insights for TFT Model ---")
print("\n".join(tft_model_notes))

# Summary
findings_summary = [
    "• Sales data shows distinct patterns by product category, industry, and time",
    "• Festivals have measurable impact on sales performance",
    "• Previous quarter sales are strong predictors of current sales",
    "• The data includes necessary variables for TFT modeling",
    "• Next steps: Feature engineering and preparing data for TFT format",
]
print("\n--- Summary of Findings ---")
print("\n".join(findings_summary))